I see that the HF Trainer run_qa script is compatible with SageMaker Distributed Data Parallel, but I don't see where it is configured.
In particular, I can see in training_args that smdistributed gets imported and configured, but where does the model get wrapped with the smdistributed DDP class?
According to the smdistributed docs, the snippet below is a required step; I'd like to understand where this happens with the HF Trainer:
# Net and device are placeholders from the smdistributed docs
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
model = DDP(Net().to(device))
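For reference, here is a fuller sketch of the setup as I read the smdistributed (v1 API) docs; Net is just a stand-in for the docs' placeholder model, and the process-group/device handling reflects my understanding of the docs rather than anything I've verified inside the Trainer:

import torch
import torch.nn as nn
import smdistributed.dataparallel.torch.distributed as dist
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP

# Trivial placeholder standing in for the docs' Net
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

dist.init_process_group()            # initialize the SMDDP process group
local_rank = dist.get_local_rank()   # one process per GPU
torch.cuda.set_device(local_rank)    # pin this process to its GPU
model = DDP(Net().to(local_rank))    # the wrap I'd expect Trainer to do somewhere

What I'd like to confirm is whether the HF Trainer performs an equivalent wrap internally, and if so, where in its code that happens.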