Thanks. I did as you suggested, but the training loop is still making very slow progress.
OLD:
epoch_loss = 0.0
for i, batch in enumerate(dl):
    loss = loss_fn(yhat, y)
    loss.backward()
    epoch_loss += loss.item()
return epoch_loss / len(dl)
NEW:
epoch_loss = 0.0
for i, batch in enumerate(dl):
    loss = loss_fn(yhat, y)
    loss.backward()
    epoch_loss += loss.detach()  # <-- NEW: accumulate as a tensor on-device instead of calling .item() every batch
return epoch_loss.item() / len(dl)  # <-- NEW: single device-to-host sync at the end of the epoch
A single batch still takes a very long time to complete. I suspect it's running on the CPU rather than the TPU, even though I think I followed all the XLA setup steps correctly. If this turns out to be outside Transformers' domain, I'll go bug the XLA folks.
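
For reference, this is the quick check I plan to run to confirm which hardware the tensors actually land on. It's just a sketch: I'm assuming my model object is called model and that batches unpack as (x, y), and xm.xla_device_hw is the helper I remember from torch_xla, so the exact call may differ across versions:

import torch_xla.core.xla_model as xm

device = xm.xla_device()
print("XLA device:", device)                              # e.g. xla:0 or xla:1
print("Underlying hardware:", xm.xla_device_hw(device))   # expect 'TPU', not 'CPU'

# Where do the model parameters live? They should report the xla device.
print("Model params on:", next(model.parameters()).device)

# Where does a batch end up after the transfer I do in the loop?
x, y = next(iter(dl))
print("Batch lands on:", x.to(device).device)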