Tips for training LongT5

TL;DR: Use --gradient_checkpointing when training LongT5 on long documents if you run into memory issues.

While experimenting with LongT5 on long source documents (max_source_length=16384), I kept running into memory issues. For reference, I am using 8x A100 GPUs with 40GB of memory each. I could not train the “large” version of the model with max_source_length above 4000, and even DeepSpeed did not help much.

I then found that with the Trainer argument --gradient_checkpointing I was able to train LongT5-large with max_source_length=16384. Adding DeepSpeed (ZeRO stage 2 and 3) on top of that allowed slightly larger batch sizes but did not seem to improve training speed. Note that I am not a DeepSpeed expert, so I mostly relied on the “auto” settings to configure it.