Fine-Tuning / Pre-Training Tips

Copy-pasting some internal discussions here as it might be interesting for everyone. Feel free to comment / ask questions!


I only have experience with training models <= 2B parameters, only in PyTorch / Flax, and only using Accelerate and the Examples/Trainer on TPU & GPU.

  • TPU or GPU? For fine-tuning models <= 2B, if someone doesn’t have experience with TPUs, it doesn’t really make sense to jump into TPUs (might be different for TF), because the models are usually fine-tuned pretty quickly (less than a day), so the actual cost saved per training run is negligible. Also, it’s just much easier to quickly experiment with PyTorch than with Flax (more docs in PyTorch, much better GPU support)
  • Trainer or accelerate? For fine-tuning I usually use the Trainer because it works out of the box for pretty much all use-cases. I haven’t used accelerate for fine-tuning yet, but we also have lots of working examples for it now - so I think both Trainer and accelerate make sense here. Trainer is better atm IMO because it has great support for directly uploading checkpoints to the Hub, automatically creates the README, etc…
  • Optimizations:
    • fp16? bfloat16? For more or less all models that come from Facebook/Fairseq or Microsoft (BART, Wav2Vec2, RoBERTa, WavLM, DialoGPT), I would do training in fp16, because these models have usually been pretrained in PyTorch and usually in fp16. Lots of time and memory can be saved for more or less the same performance. I’ve actually never not trained in fp16 for Facebook or MS models. Google models, especially T5, have mostly been pre-trained in bfloat16 on TPU, which means that fine-tuning in fp16 might break! This is the case for all mid- to larger-sized T5 models. For those I usually just use full fp32 precision, because bfloat16 has only very recently been supported in native PyTorch, I think. Guess @Stas knows more here.
    • Other optimizations? PyTorch is usually not compiled during training, so it works well with dynamic batch shapes. This means I always use dynamic padding when training in PyTorch by setting group_by_length to True. group_by_length groups inputs according to their length to make sure batches contain input samples of more or less the same length. This has a couple of advantages over not using it:
      • More efficient, since fewer padding tokens are processed
      • Less padding can also mean, for some models (like Wav2Vec2), less instability, because fewer tokens are masked out with 10,000.0
      • @Sylvain Gugger added a very nice hack when he added group_by_length which puts the largest batch first. This means you’ll know directly whether your model will OOM during training, and won’t have any bad surprises after 2 hours of training
    • I also always use gradient_checkpointing because it saves a lot of memory during training, especially when the model has many layers. If the model still doesn’t fit into memory, I first replace PyTorch’s native Adam optimizer (used by default) with the 8-bit Adam optimizer, which in some cases saved a lot of memory for me. If that also doesn’t work, I make use of gradient_accumulation_steps
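The last fallback, gradient accumulation, can be sketched in plain PyTorch to show the core idea: accumulate gradients over several small micro-batches before each optimizer step, simulating a larger effective batch size. This is a minimal sketch, not the actual training script - the tiny linear model, batch sizes, and random data are placeholders:

```python
import torch

model = torch.nn.Linear(10, 2)             # stand-in for a large model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
accumulation_steps = 4                      # effective batch = 4 x micro-batch

optimizer.zero_grad()
for step in range(8):                       # 8 micro-batches -> 2 optimizer steps
    inputs = torch.randn(2, 10)             # tiny micro-batch that fits in memory
    labels = torch.randint(0, 2, (2,))
    loss = torch.nn.functional.cross_entropy(model(inputs), labels)
    # scale the loss so the accumulated gradients average correctly
    (loss / accumulation_steps).backward()
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```

With the Trainer, the same effect comes from setting gradient_accumulation_steps in the training arguments instead of writing the loop by hand.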


I only have experience with training models <= 2B parameters, only in PyTorch / Flax, and only using Accelerate, Examples/Trainer on TPU & GPU.

  • TPU or GPU? Here I do recommend using TPU if the person is accustomed to TPUs. It can lead to significant speed-ups and cost savings. PyTorch/XLA on TPU can work very well, but doesn’t have great support. Flax on TPU is more robust and faster than PyTorch/XLA in my experiments, but requires someone to know JAX/Flax. For more model-specific libraries, the TF-T5 and Jax-T5 libraries are great for pre-training T5, and it’s easy to convert models from those libraries to Transformers.
  • If I use PyTorch, I use accelerate rather than the Trainer, because I want a more customized training loop (logging special metrics, …)
  • For pre-training, the same optimization tips apply as the ones written above, except for group_by_length, which is often not necessary because one can “blockify” the data so that no padding at all is needed

Helpful additional links