T5 pre-training is now supported in JAX/FLAX. You can check out the example script here: transformers/examples/flax/language-modeling at master · huggingface/transformers · GitHub. It includes two scripts:
- t5_tokenizer_model.py, to train a T5 tokenizer (i.e. SentencePiece) from scratch (see the sketch after this list).
- run_t5_mlm_flax.py, to pre-train T5. It's suited to run on TPUs, for which you can obtain free access by applying to Google's TFRC program.
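As a rough orientation, here is a minimal sketch of how the tokenizer script can be used, based on the example's README: the SentencePieceUnigramTokenizer class is the one defined in t5_tokenizer_model.py, while the dataset choice and output directory are just illustrative placeholders.

```python
import datasets
from t5_tokenizer_model import SentencePieceUnigramTokenizer  # class defined in the example script

vocab_size = 32_000

# Any raw-text dataset works here; Norwegian OSCAR is only an illustrative choice.
dataset = datasets.load_dataset("oscar", name="unshuffled_deduplicated_no", split="train")

tokenizer = SentencePieceUnigramTokenizer(unk_token="<unk>", eos_token="</s>", pad_token="<pad>")

def batch_iterator(batch_size=1000):
    # Yield batches of raw text for the tokenizer trainer.
    for i in range(0, len(dataset), batch_size):
        yield dataset[i : i + batch_size]["text"]

# Train the Unigram (SentencePiece-style) tokenizer from scratch.
tokenizer.train_from_iterator(batch_iterator(), vocab_size=vocab_size, show_progress=True)

# Save the tokenizer, plus a matching config, in the directory the pre-training script will read from.
tokenizer.save("./my-t5-base/tokenizer.json")

from transformers import T5Config
config = T5Config.from_pretrained("google/t5-v1_1-base", vocab_size=vocab_size)
config.save_pretrained("./my-t5-base")
```

After that, run_t5_mlm_flax.py can be pointed at the same directory via its --tokenizer_name and --config_name arguments; see the example's README for the full command with all training hyperparameters.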
@patrickvonplaten also demonstrates how to run the script in this video (starts around 13:35).
This script was developed for the JAX/FLAX community event. It would be really cool if someone contributed a PyTorch version of it. That would mean translating the script from FLAX to PyTorch, which is probably straightforward.