At first glance, you have `loss.item()` in your training loop, which you should absolutely avoid on TPUs: every `.item()` call forces a device-to-host transfer and synchronization, which is a big slowdown. You should use `loss.detach()` to accumulate your losses on the TPU, then call `.item()` only once at the very end of your epoch.
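For illustration, here is a minimal sketch of that pattern. The function name `train_one_epoch` and the `model`/`loader`/`optimizer`/`device` arguments are hypothetical stand-ins for your own setup (on TPU, `device` would come from `torch_xla`):

```python
import torch
import torch.nn.functional as F

def train_one_epoch(model, loader, optimizer, device):
    model.train()
    # Accumulator lives on the device, so no host sync happens per step
    running_loss = torch.tensor(0.0, device=device)
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = F.cross_entropy(model(inputs), targets)
        loss.backward()
        optimizer.step()
        # detach() drops the graph but keeps the value on the TPU;
        # crucially, no .item() here
        running_loss += loss.detach()
    # Single device-to-host transfer, once per epoch
    return (running_loss / len(loader)).item()
```

The key point is that `running_loss += loss.detach()` stays entirely on the device, so the XLA graph isn't interrupted at every step; only the final `.item()` pays the transfer cost.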