Here is how I'm launching the SageMaker training job:
# Launch a SageMaker PyTorch training job for the phi-1.5 entry-point script.

# Distribution strategy passed to the estimator: the "pytorchddp" launcher,
# switched on. (The job has also been tried with `distribution` left out.)
# NOTE(review): with instance_count=1, DDP only adds parallelism if the
# instance type exposes more than one GPU -- confirm the GPU count of
# ml.g5.16xlarge before relying on it.
_ddp_distribution = {"pytorchddp": {"enabled": True}}

pt_estimator = PyTorch(
    entry_point="ph_1_5_with_accelerator.py",  # script run as the training entry point
    source_dir="source_dir_phi_1_5",  # directory holding the entry-point script
    role=get_execution_role(),  # execution role the job runs under
    framework_version="1.10.2",
    py_version="py38",
    instance_count=1,
    instance_type="ml.g5.16xlarge",
    distribution=_ddp_distribution,
)

# Start training. fit() receives no inputs here, so presumably the script
# obtains its own data -- confirm.
pt_estimator.fit()