Mmmm, I don’t see a call to the spawn function, so yes, you’re probably training on the CPU. Normally you are supposed to launch train through xmp.spawn:
import torch_xla.distributed.xla_multiprocessing as xmp

# Forks one process per TPU core; each process calls train(index, *potential_args),
# where index is the ordinal of the process, so train must accept it as its first argument.
xmp.spawn(train, args=potential_args, nprocs=num_tpu_cores)