T5 fp16 issue is fixed

We have just fixed the T5 fp16 issue for some of the T5 models!

(Announcing it here, since lots of users were facing this issue and T5 is one of the most widely used models in the library.)


Previously, T5 models produced nan loss and logits when run in fp16. This is now fixed on master for the following T5 models and versions, so you should be able to train these models and run inference with them in fp16 and see a decent speed-up!

  • T5v1: t5-small, t5-base, t5-large
  • T5v1_1: google/t5-v1_1-small, google/t5-v1_1-base
  • MT5: google/mt5-small, google/mt5-base

For those of you who are interested, here’s a description of what was causing nan loss and how it is fixed.

t5-small was the only T5 model that worked in fp16. The rest of the models produced nan loss/logits.

For all the models and versions (v1, v1.1, mT5), at some point we get inf values in hidden_states after applying the final linear layer (wo) in T5DenseReluDense and T5DenseGatedGeluDense, which then results in nan values in T5LayerNorm.
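A toy sketch of this failure mode, using NumPy's float16 as a stand-in for torch fp16. The weight and activation values are made up to force an overflow and do not come from a real checkpoint. The wo-style matmul exceeds the fp16 maximum of 65504 and becomes inf. T5LayerNorm is an RMS norm that multiplies by rsqrt(variance), so an inf input gives rsqrt(inf) = 0 and then inf * 0 = nan:

```python
import numpy as np

def t5_rms_norm(x, eps=1e-6):
    # RMS-style norm (no mean subtraction, no bias), with the variance computed in
    # float32, mirroring how the transformers T5LayerNorm is written
    x32 = x.astype(np.float32)
    variance = np.mean(x32 ** 2, axis=-1, keepdims=True)
    return x32 * (1.0 / np.sqrt(variance + eps))  # x * rsqrt(variance)

h = np.full(4, 200.0, dtype=np.float16)       # toy activations entering wo
w = np.full((4, 4), 100.0, dtype=np.float16)  # toy wo weight

with np.errstate(over="ignore", invalid="ignore"):
    ff_out16 = h @ w                  # each entry is 4 * 200 * 100 = 80000 > 65504 -> inf
    normed16 = t5_rms_norm(ff_out16)  # inf * rsqrt(inf) = inf * 0 -> nan

ff_out32 = h.astype(np.float32) @ w.astype(np.float32)  # the same matmul stays finite in fp32
```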

Also, for t5-large, t5-v1_1-base and t5-v1_1-large, there are inf values in the output of T5LayerSelfAttention and T5LayerCrossAttention, specifically where the attention output is added to hidden_states.
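The residual addition can overflow even when neither operand does. Here is a minimal NumPy illustration with made-up values: both tensors are finite in fp16, but their sum is not.

```python
import numpy as np

# Toy values: each is representable in fp16 (max 65504), their sum is not
hidden_states = np.array([40000.0, -3.0], dtype=np.float16)
attn_output = np.array([30000.0, 1.5], dtype=np.float16)

with np.errstate(over="ignore"):
    summed = hidden_states + attn_output  # 70000 overflows to inf; -1.5 is fine
```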

This happens during both training and inference. To reproduce:


To avoid inf values, we can clamp hidden_states to the largest finite value of the current data type whenever it contains inf, i.e.:

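import torch

# Only clamp when hidden_states has already overflowed; the bound sits just below the dtype's max
if torch.isinf(hidden_states).any():
    clamp_value = torch.finfo(hidden_states.dtype).max - 1000
    hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)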
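To see why the torch snippet above removes the nan, here is a NumPy re-creation of the same clamp. It uses the same dtype-max-minus-1000 bound, which for fp16 is 65504 - 1000. It also uses the same RMS-style norm as the earlier toy example. The input values are made up. Without the clamp, the norm returns nan. After clamping, every value is finite and the norm is well defined.

```python
import numpy as np

def clamp_if_inf(hidden_states):
    # NumPy version of the torch snippet: only touch the array if it already has inf
    if np.isinf(hidden_states).any():
        clamp_value = float(np.finfo(hidden_states.dtype).max) - 1000  # 64504 for fp16
        hidden_states = np.clip(hidden_states, -clamp_value, clamp_value).astype(hidden_states.dtype)
    return hidden_states

def t5_rms_norm(x, eps=1e-6):
    # RMS-style norm with the variance computed in float32
    x32 = x.astype(np.float32)
    variance = np.mean(x32 ** 2, axis=-1, keepdims=True)
    return x32 * (1.0 / np.sqrt(variance + eps))

overflowed = np.array([np.inf, 2.0, -np.inf, 0.5], dtype=np.float16)  # toy overflowed activations
clamped = clamp_if_inf(overflowed)

with np.errstate(invalid="ignore"):
    before = t5_rms_norm(overflowed)  # contains nan
after = t5_rms_norm(clamped)          # all finite
```

Note that the clamp is a no-op for arrays without inf, so it only changes activations that have already overflowed.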

We need to add this after the self-attention, cross-attention, and feed-forward layers, which is where the inf values occur. This works with both the apex and native amp fp16 backends.
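Where the clamp sits can be sketched with a hypothetical, heavily simplified T5 block. The function names and the "sublayer returns a residual update" structure are illustrative assumptions. This is not the transformers code: the pre-LayerNorm, dropout, masks and position bias inside each sublayer are all omitted. The point is only that each residual output gets the inf check before the next sublayer sees it.

```python
import numpy as np

def clamp_if_inf(h):
    # Same rule as the snippet above: clamp only if inf is already present
    if np.isinf(h).any():
        clamp_value = float(np.finfo(h.dtype).max) - 1000
        h = np.clip(h, -clamp_value, clamp_value).astype(h.dtype)
    return h

def t5_block_forward(h, self_attn, cross_attn, feed_forward):
    """Hypothetical simplified block: each sublayer returns an update added to h."""
    with np.errstate(over="ignore"):  # overflow is expected here and repaired by the clamp
        h = clamp_if_inf(h + self_attn(h))     # after self-attention (T5LayerSelfAttention)
        h = clamp_if_inf(h + cross_attn(h))    # after cross-attention (decoder blocks only)
        h = clamp_if_inf(h + feed_forward(h))  # after the feed-forward layer (T5LayerFF)
    return h
```

For example, with a toy self-attention that returns its input, `t5_block_forward(np.array([40000.0, 1.0], dtype=np.float16), lambda x: x, np.zeros_like, np.zeros_like)` overflows on the first residual add but still returns finite fp16 values.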

To verify this fix, we trained t5-base, t5-v1_1-base and t5-v1_1-small on cnn/dm for 10k steps (1.11 epochs).
Here's the training command. To run it, navigate to the examples/seq2seq directory, follow the instructions in the README to download the cnn_dm dataset, and then run:

export M=google/t5-v1_1-base
export OUT_DIR=t5-v1_1-base-cnn-fp16
export DATA_DIR=cnn_dm

python finetune_trainer.py \
    --model_name_or_path $M \
    --data_dir $DATA_DIR \
    --output_dir $OUT_DIR --overwrite_output_dir \
    --max_steps=10000 \
    --gradient_accumulation_steps=8 \
    --learning_rate=1e-4 \
    --per_device_train_batch_size=4 \
    --n_val 500 \
    --max_target_length=56 --val_max_target_length=128 \
    --fp16 --fp16_backend apex \
    --do_train --do_eval --evaluation_strategy steps \
    --logging_steps=100 --logging_first_step --eval_steps=2500 --save_steps=2500 --save_total_limit=2 \
    --sortish_sampler
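As a sanity check on the "10k steps (1.11 epochs)" figure, the arithmetic from the flags above is as follows. The post does not state the GPU count or dataset size. This sketch assumes a single GPU and roughly 287k training examples, the usual size of the CNN/DailyMail train split.

```python
def epochs_covered(max_steps, per_device_batch, grad_accum, n_gpus, n_train):
    # Examples consumed per optimizer step = per-device batch x accumulation steps x devices
    examples_per_step = per_device_batch * grad_accum * n_gpus
    return max_steps * examples_per_step / n_train

# --max_steps=10000, --per_device_train_batch_size=4, --gradient_accumulation_steps=8,
# assuming 1 GPU and 287,113 training examples
epochs = epochs_covered(10_000, 4, 8, 1, 287_113)
print(f"{epochs:.2f}")  # 1.11
```

That is 320,000 examples seen in total, with an effective batch size of 32.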

For evaluation:

python run_eval.py \
    t5-v1_1-base-cnn-fp16  cnn_dm/test.source hypothesis.txt \
    --reference_path cnn_dm/test.target \
    --score_path metrics.json \
    --device cuda:0 \
    --prefix summarize: \
    --bs 16 \
    --fp16

We got the following ROUGE-2 scores:

  1. t5-base: 19.2804
  2. t5-v1_1-base: 18.4316

(Note that the score for t5-base is higher because it's already pre-trained on cnn/dm.)
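ROUGE-2 measures bigram overlap between the generated and reference summaries. The scores above are reported scaled by 100. Here is a minimal F-measure sketch to show what is being compared. It is a simplification: lowercase whitespace tokenization and no stemming. run_eval.py may tokenize and stem differently, so this will not reproduce the numbers above exactly.

```python
from collections import Counter

def bigrams(text):
    tokens = text.lower().split()
    return Counter(zip(tokens, tokens[1:]))

def rouge2_f(candidate, reference):
    cand, ref = bigrams(candidate), bigrams(reference)
    overlap = sum((cand & ref).values())  # clipped count of shared bigrams
    if overlap == 0:
        return 0.0
    precision = overlap / sum(cand.values())
    recall = overlap / sum(ref.values())
    return 2 * precision * recall / (precision + recall)

# 3 of 5 bigrams match in each direction, so precision = recall = F = 0.6
print(round(rouge2_f("the cat sat on a mat", "the cat sat on the mat"), 4))  # 0.6
```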

For comparison, we evaluated the pre-trained t5-base in both fp32 and fp16, which gave the following results:

  1. fp16: 18.3681
  2. fp32: 18.394

So the fp16 and fp32 results are close (they differ by about 0.03 ROUGE-2).

To verify the fix for t5-large, we evaluated the pre-trained t5-large in fp32 and fp16 (using the same evaluation command as above) and got the following results:

  1. fp16: 19.2734
  2. fp32: 19.2342

Surprisingly, ROUGE-2 is slightly higher in fp16 (by about 0.04).
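To put both fp16 vs fp32 comparisons on the same scale, the gap can be expressed relative to the fp32 score. The scores below are the ones reported above. In both cases the difference is well under 1%.

```python
def relative_gap(fp16_score, fp32_score):
    """Absolute and percentage difference of the fp16 score from the fp32 score."""
    diff = fp16_score - fp32_score
    return diff, 100 * diff / fp32_score

for name, fp16, fp32 in [("t5-base", 18.3681, 18.394), ("t5-large", 19.2734, 19.2342)]:
    diff, pct = relative_gap(fp16, fp32)
    print(f"{name}: {diff:+.4f} ROUGE-2 ({pct:+.2f}%)")
# t5-base: -0.0259 ROUGE-2 (-0.14%)
# t5-large: +0.0392 ROUGE-2 (+0.20%)
```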

So with the above fix, the following models now work in fp16 (apex opt level O1) and give a decent speed-up:

  • T5v1: t5-small, t5-base, t5-large
  • T5v1_1: google/t5-v1_1-small, google/t5-v1_1-base
  • MT5: google/mt5-small, google/mt5-base

One interesting observation:
For inference, t5-base fine-tuned in fp16 and evaluated in fp32 is faster (~1.31x) than the pre-trained t5-base evaluated in fp16. See this colab.


Nice fix!
The speed discrepancy might be due to different generation lengths.
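One way to check the generation-length explanation is to compare time per generated token rather than raw wall-clock time. The helper below is a hypothetical sketch. The timings and token counts you pass in would come from your own runs, and the numbers in the example are made up.

```python
def per_token_speedup(time_a, tokens_a, time_b, tokens_b):
    """Speed-up of run A over run B after normalizing by the number of generated tokens.

    A raw ratio (time_b / time_a) can look like a speed-up simply because run A
    generated shorter outputs; dividing each time by its token count removes that.
    """
    sec_per_token_a = time_a / tokens_a
    sec_per_token_b = time_b / tokens_b
    return sec_per_token_b / sec_per_token_a

# Made-up example: run B takes 1.31x longer but also generates 1.31x more tokens,
# so the per-token speed is the same and the ratio is about 1.0
ratio = per_token_speedup(100.0, 10_000, 131.0, 13_100)
```

If the normalized ratio stays near 1.0 while the raw ratio is ~1.31x, the difference is explained by output length rather than by fp16 vs fp32 compute.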