Where is SageMaker Distributed configured in HF Trainer?

I see that the HF Trainer run_qa script is compatible with SageMaker Distributed Data Parallel, but I don’t see where it is configured.

In particular, I can see in the training_args that smdist gets imported and configured, but where is the model wrapped with smdist DDP?

According to the smdist docs, the snippet below is a required step; I’d like to understand where this is done with the HF Trainer:

from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP

# Net is the user's torch.nn.Module and device is the local GPU in the SageMaker example
model = DDP(Net().to(device))

Hey @OlivierCR,

Both the SageMaker Distributed Data-Parallel and Model-Parallel libraries are directly integrated into the Trainer API, which initializes and uses them automatically.
For SMD:

  1. The library is first imported with an alias for the default PyTorch DDP library here,
  2. and the model is then wrapped here.

P.S. The _wrap_model() function also handles SMP; a simplified sketch of the data-parallel path is shown below.
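
For reference, here is a rough sketch of what that integration looks like. This is simplified and the names (wrap_model, sagemaker_dp_enabled) are illustrative, not the actual transformers source; the real guard and wrapping logic live in training_args.py and Trainer._wrap_model().

import torch

try:
    # SageMaker's drop-in replacement is imported under the same alias as
    # PyTorch's native DDP, so the downstream Trainer code stays identical.
    from smdistributed.dataparallel.torch.parallel.distributed import (
        DistributedDataParallel as DDP,
    )
    sagemaker_dp_enabled = True
except ImportError:
    from torch.nn.parallel import DistributedDataParallel as DDP
    sagemaker_dp_enabled = False


def wrap_model(model: torch.nn.Module) -> torch.nn.Module:
    # Roughly what _wrap_model() does for the data-parallel case: the model is
    # handed to whichever DDP class was imported above. The real Trainer also
    # checks local_rank / n_gpu before deciding whether to wrap at all.
    return DDP(model)

So from a user's point of view there is nothing extra to add to the training script: when the SageMaker data-parallel library is available and enabled, the Trainer performs the wrapping for you.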


wow so amazing, good job guys