You are not padding your inputs and targets to a fixed size in this example, but dynamically padding them to the longest input/target in each batch. This causes the TPU to recompile at each step, so it’s normal to see a very long training time compared to GPUs.
To properly train on TPU, you need to apply fixed padding in `tokenize_and_align_labels` to a given length of your choice, and pad the labels to that same length.
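Here is a minimal sketch of what that could look like, assuming the usual token-classification preprocessing (a fast tokenizer and a dataset with `"tokens"` and `"ner_tags"` columns); the `MAX_LENGTH` value is just an example you should adapt to your data:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
MAX_LENGTH = 128  # fixed length of your choice

def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(
        examples["tokens"],
        truncation=True,
        is_split_into_words=True,
        padding="max_length",   # pad every sample to the same fixed length
        max_length=MAX_LENGTH,
    )

    labels = []
    for i, label in enumerate(examples["ner_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        label_ids = []
        previous_word_idx = None
        for word_idx in word_ids:
            if word_idx is None:
                # Special tokens and padding tokens get -100 so they are
                # ignored by the loss; this also pads the labels to MAX_LENGTH.
                label_ids.append(-100)
            elif word_idx != previous_word_idx:
                label_ids.append(label[word_idx])
            else:
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs
```

With this, every batch has the same shape, so XLA compiles the training step once instead of recompiling whenever the longest sequence in the batch changes.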