T5 on TPU implementation

I am trying to use Suraj's implementation of T5 on TPU.

I’ve worked through some issues with it where deprecations have come into play.

I am invoking this code:
xmp.spawn(_mp_fn, args=(), nprocs=8, start_method='fork')

and it seems to be executing, but with errors like the following:

File "", line 44, in __call__
input_ids = torch.stack([example['input_ids'] for example in batch])
RuntimeError: stack expects each tensor to be equal size, but got [375] at entry 0 and [551] at entry 1
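
For reference, here is a minimal sketch of what the error seems to mean, outside the notebook. It assumes plain PyTorch only (no TPU), a made-up batch with the two lengths from the traceback, and a hypothetical helper `collate_padded` that is not part of the original notebook. `torch.stack` requires every tensor to have the same shape, so the examples would need to be padded to a common length before stacking:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Hypothetical batch using the two lengths reported in the traceback.
batch = [
    {"input_ids": torch.ones(375, dtype=torch.long)},
    {"input_ids": torch.ones(551, dtype=torch.long)},
]

# torch.stack needs equal shapes, so this reproduces the RuntimeError.
try:
    torch.stack([example["input_ids"] for example in batch])
    stack_failed = False
except RuntimeError:
    stack_failed = True


def collate_padded(batch, pad_token_id=0):
    """Pad each example to the longest one in the batch (hypothetical helper)."""
    ids = [example["input_ids"] for example in batch]
    # pad_sequence right-pads every tensor to the longest length in `ids`.
    input_ids = pad_sequence(ids, batch_first=True, padding_value=pad_token_id)
    # Build the mask from the original lengths: 1 for real tokens, 0 for padding.
    attention_mask = torch.zeros_like(input_ids)
    for row, example_ids in enumerate(ids):
        attention_mask[row, : example_ids.size(0)] = 1
    return {"input_ids": input_ids, "attention_mask": attention_mask}


padded = collate_padded(batch)
print(stack_failed, padded["input_ids"].shape)
```

If this reading is right, the underlying cause is probably that the examples were not padded or truncated to a fixed length when they were tokenized. On TPU a single fixed length may be preferable to per-batch padding, since changing tensor shapes can trigger XLA recompilation.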

I am not sure whether the run will actually complete successfully. I have already let it run for many hours, whereas for the next cells the author specifically mentions that it takes 40 minutes or so.

I am conscious that the notebook is just over 2 years old. Any help or thoughts would be appreciated. I am on a Colab Pro account.