Should we optimize the logic for enabling TorchXLA in a GPU environment?

I followed this blog post (Large Scale Training of Hugging Face Transformers on TPUs With PyTorch/XLA FSDP | PyTorch) to run the FSDP job.
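
For context, here is a minimal sketch of the kind of enablement check I have in mind. This is my own assumption of what such logic might look like, not code from the blog post: it only treats TorchXLA as enabled when the `torch_xla` package is installed and an XLA runtime is explicitly selected via the standard `PJRT_DEVICE` environment variable, so a plain GPU box would not pick up XLA by accident.

```python
# Hypothetical sketch (not from the blog post): gate TorchXLA usage on
# the runtime environment instead of only on whether torch_xla imports.
import importlib.util
import os


def xla_enabled() -> bool:
    """Return True only if torch_xla is installed AND an XLA backend is
    explicitly configured via PJRT_DEVICE (e.g. "TPU")."""
    # Check that the torch_xla package is importable at all.
    if importlib.util.find_spec("torch_xla") is None:
        return False
    # PJRT_DEVICE is the PyTorch/XLA runtime selector; on a plain GPU
    # machine it is typically unset, so XLA stays disabled there.
    return os.environ.get("PJRT_DEVICE", "") != ""
```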