Should we optimize the logic for enabling TorchXLA in a GPU environment?

Recently, we have been trying to accelerate Transformers model training in a GPU environment using TorchXLA within the Accelerate library. However, we encountered some issues:

1. To run a native torch training job (i.e., without TorchXLA), users currently have to uninstall TorchXLA from their environment.

It would be more convenient if users could run both native torch training jobs and TorchXLA jobs within the same environment, without having to uninstall or reinstall TorchXLA. Maybe we should reconsider the logic of the is_tpu_available function, as sketched below.
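For example, one option is to gate XLA usage on an explicit opt-in flag instead of only on whether torch_xla can be imported. A minimal sketch, assuming a hypothetical USE_TORCH_XLA environment variable (not an existing Accelerate setting):

```python
import importlib.util
import os


def is_torch_xla_enabled() -> bool:
    """Report XLA as available only when torch_xla is installed *and*
    the user has explicitly opted in."""
    # Hypothetical opt-in flag; the current is_tpu_available() logic
    # only checks whether torch_xla is importable.
    if os.environ.get("USE_TORCH_XLA", "0").lower() in ("0", "false", "no"):
        return False
    return importlib.util.find_spec("torch_xla") is not None
```

With something like this, installing torch_xla alongside regular PyTorch would no longer force every job onto the XLA path.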

2. Use PyTorch's AMP instead of the XLA_DOWNCAST_BF16 option

When bf16 is enabled, the XLA_DOWNCAST_BF16 option is set. In a GPU environment, it is recommended to use PyTorch's native AMP autocast rather than setting the XLA_USE_BF16 or XLA_DOWNCAST_BF16 options, and additionally to use the GradScaler and syncfree AdamW optimizer from TorchXLA. A sketch of such a training step follows.
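Roughly, the training loop would look like the sketch below. It assumes the torch_xla.amp module (autocast, GradScaler, and the syncfree optimizers) shipped with recent torch_xla releases; exact import paths may differ between versions, and MyModel / loader are placeholders:

```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.amp import autocast, syncfree

device = xm.xla_device()                      # XLA device (GPU in this setup)
model = MyModel().to(device)                  # placeholder model
optimizer = syncfree.AdamW(model.parameters(), lr=1e-4)

for batch in loader:                          # placeholder dataloader
    optimizer.zero_grad()
    # Native AMP autocast instead of XLA_USE_BF16 / XLA_DOWNCAST_BF16
    with autocast(device, dtype=torch.bfloat16):
        loss = model(**batch).loss
    loss.backward()
    optimizer.step()          # for fp16, wrap with torch_xla.amp.GradScaler
    xm.mark_step()            # materialize the lazy XLA graph
```

Bf16 does not need loss scaling, so the GradScaler is only required when training in fp16.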

3. When running an FSDP job using TorchXLA as the backend, I encountered the error "Object of type dtype is not JSON serializable"

The reason for the error is that the TorchXLA FSDP configuration stores torch.bfloat16 in the training arguments, and a torch.dtype cannot be serialized to JSON. https://github.com/huggingface/transformers/blob/d7cb5e138ec1ccc848a554574b1a89f0dfaf0e90/src/transformers/training_args.py#L1576

If TensorBoard is enabled, the error occurs at this line: https://github.com/huggingface/transformers/blob/d7cb5e138ec1ccc848a554574b1a89f0dfaf0e90/src/transformers/integrations/integration_utils.py#L627 A possible workaround is sketched below.
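One way to avoid this is to convert any torch.dtype values to strings before the arguments are dumped to JSON. A minimal sketch; the sanitize_for_json helper is hypothetical, not an existing Transformers function:

```python
import json

import torch


def sanitize_for_json(obj):
    """Recursively replace values json cannot handle (e.g. torch.dtype)
    with their string representation."""
    if isinstance(obj, torch.dtype):
        return str(obj)  # torch.bfloat16 -> "torch.bfloat16"
    if isinstance(obj, dict):
        return {k: sanitize_for_json(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [sanitize_for_json(v) for v in obj]
    return obj


# e.g. before logging/serializing the training arguments:
# json.dumps(sanitize_for_json(args.to_dict()))
```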

After fixing these issues, we achieved a good speedup on GPU.


I followed this blog post (Large Scale Training of Hugging Face Transformers on TPUs With PyTorch/XLA FSDP | PyTorch) to run the FSDP job.

PRs for all of this would be great :slight_smile: We have not done much to optimize XLA on GPUs (because we haven't really tested them), so any improvements to what we currently have would be great!

(If not, open an issue and we can know to flag and work on it)

Thanks for your reply :grinning:.
Okay, I will propose a PR later.