T5 fp16 issue is fixed

We have just fixed the T5 fp16 issue for some of the T5 models!

(Announcing it here, since lots of users were facing this issue and T5 is one of the most widely used models in the library.)

TL;DR:

Previously, there was an issue when using T5 models in fp16: they produced nan loss and logits. On master, this issue is now fixed for the following T5 models and versions, so you should be able to train and run inference with these models in fp16 and see a decent speed-up!

  • T5v1 : t5-small , t5-base , t5-large
  • T5v1_1 : google/t5-v1_1-small , google/t5-v1_1-base
  • MT5 : google/mt5-small , google/mt5-base

For those of you who are interested, here’s a description of what was causing nan loss and how it is fixed.

t5-small was the only T5 model that worked in fp16; the rest of the models produced nan loss/logits.

For all the models and versions (v1, v1.1, mT5), at some point we get inf values in hidden_states after applying the final linear layer (wo) in T5DenseReluDense and T5DenseGatedGeluDense, which results in nan values in T5LayerNorm.

Also, for t5-large, t5-v1_1-base, and t5-v1_1-large, there are inf values in the output of T5LayerSelfAttention and T5LayerCrossAttention, specifically where we add the attention output to the hidden_states.

This happens during both training and inference.

Fix

To avoid inf values, we can clamp the hidden_states to the max value for the current data type whenever it contains any inf values, i.e.:

if torch.isinf(hidden_states).any():
    clamp_value = torch.finfo(hidden_states.dtype).max - 1000
    hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)

We need to add this after the self-attention, cross-attention, and feed-forward layers, which is where the inf values occur. This works for both apex and amp.
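
For illustration, here is a minimal sketch of a T5-style feed-forward block with the clamp applied right after the wo projection; this is my simplification, not the actual transformers code, and the same guard goes after the self-attention and cross-attention outputs.

import torch
from torch import nn

class FeedForwardWithClamp(nn.Module):
    # Simplified T5-style feed-forward block showing where the clamp goes.
    def __init__(self, d_model=512, d_ff=2048):
        super().__init__()
        self.wi = nn.Linear(d_model, d_ff, bias=False)
        self.wo = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, hidden_states):
        hidden_states = torch.relu(self.wi(hidden_states))
        # the wo projection is where the values overflow in fp16
        hidden_states = self.wo(hidden_states)

        # clamp any inf values, as in the fix above
        if torch.isinf(hidden_states).any():
            clamp_value = torch.finfo(hidden_states.dtype).max - 1000
            hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
        return hidden_states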

To verify this fix, I trained t5-base, t5-v1_1-base, and t5-v1_1-small on CNN/DM for 10k steps (1.11 epochs).
Here’s the training command. To run it, navigate to the examples/seq2seq directory, follow the instructions in the README to download the CNN/DM dataset, and then run the following command:

export M=google/t5-v1_1-base
export OUT_DIR=t5-v1_1-base-cnn-fp16
export DATA_DIR=cnn_dm

python finetune_trainer.py \
    --model_name_or_path $M \
    --data_dir $DATA_DIR \
    --output_dir $OUT_DIR --overwrite_output_dir \
    --max_steps=10000 \
    --gradient_accumulation_steps=8 \
    --learning_rate=1e-4 \
    --per_device_train_batch_size=4 \
    --n_val 500 \
    --max_target_length=56 --val_max_target_length=128 \
    --fp16 --fp16_backend apex \
    --do_train --do_eval --evaluation_strategy steps \
    --logging_steps=100 --logging_first_step --eval_steps=2500 --save_steps=2500 --save_total_limit=2 \
    --sortish_sampler

For evaluation:

python run_eval.py \
    t5-v1_1-base-cnn-fp16  cnn_dm/test.source hypothesis.txt \
    --reference_path cnn_dm/test.target \
    --score_path metrics.json \
    --device cuda:0 \
    --prefix summarize: \
    --bs 16 \
    --fp16

This gave the following metrics (ROUGE-2):

  1. for t5-base: 19.2804
  2. for t5-v1.1-base: 18.4316
(note that the score for t5-base is higher because it’s already pre-trained on CNN/DM)

For comparison, I evaluated the pre-trained t5-base in both fp32 and fp16, which gave the following results:

  1. fp16: 18.3681
  2. fp32: 18.394

So the results are close enough.

To verify the fix for t5-large, I evaluated the pre-trained t5-large in fp32 and fp16 (using the same command as above) and got the following results:

  1. fp16: 19.2734
  2. fp32: 19.2342

Surprisingly, ROUGE-2 is slightly better in fp16.

So with the above fix, the following model types now work in fp16 (opt level O1) and give a decent speed-up :slight_smile:

  • T5v1: t5-small, t5-base, t5-large
  • T5v1_1: google/t5-v1_1-small, google/t5-v1_1-base
  • MT5: google/mt5-small, google/mt5-base

One interesting observation: for inference, the t5-base fine-tuned with fp16 and evaluated in fp32 is faster (~1.31x) than the pre-trained t5-base evaluated in fp16. See this colab.
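
As an example, here is a minimal sketch of fp16 inference with one of the fixed checkpoints (assuming a CUDA device; the checkpoint name and input text are just placeholders):

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

model_name = "t5-base"  # any of the checkpoints listed above
tokenizer = T5Tokenizer.from_pretrained(model_name)
# cast the weights to fp16 and move the model to the GPU
model = T5ForConditionalGeneration.from_pretrained(model_name).half().to("cuda").eval()

inputs = tokenizer("summarize: " + "long article text ...", return_tensors="pt").to("cuda")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_length=56)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))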


Nice fix!
The speed discrepancy might be because of different generation lengths.

Very cool! Works well on T4 as well.

Any guesses why inference in fp16 doesn’t seem to be noticeably faster? I saw only a small difference in the shared colab, and I’m seeing similar behavior locally…

Hey @aaronchavez

As explained in this issue (Improve PyTorch examples for FP16 · Issue #9752 · huggingface/transformers · GitHub):
To get the full speed-up of FP16 training, every tensor passed through the model should have all its dimensions be a multiple of 8.
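
One way to get sequence lengths that are multiples of 8 is the tokenizer’s pad_to_multiple_of argument. A minimal sketch (the checkpoint and texts are placeholders):

from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-base")
texts = ["summarize: first example ...", "summarize: a second, somewhat longer example ..."]

# pad the batch so the sequence-length dimension is a multiple of 8
batch = tokenizer(texts, padding=True, pad_to_multiple_of=8, return_tensors="pt")
print(batch["input_ids"].shape)  # the sequence length is rounded up to a multiple of 8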

Thanks for replying! I was wondering about inference in particular (via generate())… is it the same situation there? I am seeing approximately the same times for generation with fp16 and fp32.

Hi

This is not fixed as far as I can see; could you have a look here please:

I really need these models and would appreciate you looking into it.
Thanks a lot

@valhalla @patrickvonplaten

Hi Dara,
Can you specify which model you’re trying to use?

Hi there
I use mt5-small; more info is provided in the linked issues that I and other users opened. Thanks

@valhalla I tried the finetune_trainer.py command you mentioned.

I got this error.

07/23/2021 06:26:38 - INFO - __main__ - *** Train ***
/usr/local/lib/python3.7/dist-packages/transformers/trainer.py:1026: FutureWarning: `model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` instead.
  FutureWarning,
Traceback (most recent call last):
  File "/content/gdrive/MyDrive/Colab Notebooks/generation/transformers/examples/legacy/seq2seq/finetune_trainer.py", line 367, in <module>
    main()
  File "/content/gdrive/MyDrive/Colab Notebooks/generation/transformers/examples/legacy/seq2seq/finetune_trainer.py", line 305, in main
    model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
  File "/usr/local/lib/python3.7/dist-packages/transformers/trainer.py", line 1138, in train
    self.create_optimizer_and_scheduler(num_training_steps=max_steps)
  File "/content/gdrive/MyDrive/Colab Notebooks/generation/transformers/examples/legacy/seq2seq/seq2seq_trainer.py", line 118, in create_optimizer_and_scheduler
    if self.sharded_dpp:
AttributeError: 'Seq2SeqTrainer' object has no attribute 'sharded_dpp'

I have no idea how to fix this issue. How can I fix this?

Hi @valhalla

The fix does not seem to be applied to the MT5 model:

  1. Nan losses occur when fine-tuning the model in fp16 with the latest (stable) version of transformers.
  2. The same model works well in fp32.

Do you see the same behavior?


I had the same issue, and changing to fp16=False made it work. Although the loss became a non-nan value, it frequently caused out-of-memory errors, making it impossible to train further. I tested many things and finally tried this:
fp16_full_eval=True instead of fp16=True. It solved the loss problem, but in the end the ROUGE metrics were all zero. How can it compute the loss but not the metrics?

Are the changes already merged?
I am unable to fine-tune mT5 with fp16.

@valhalla

So where is this fix now? I don’t understand where I should place the code. Is it in a new version of transformers?

I am getting the same issue with the google/flan-t5-small model with fp16=True. Has the error been reported for this model too?


I am getting the same issue with Salesforce/blip2-flan-t5-xl, which uses google/flan-t5-xl, with amp.
I am using transformers==4.29.2.
I got inf values in hidden_states after applying the final linear layer (wo) in T5DenseGatedActDense. I found this using PyCharm’s debug mode.

By the way, I didn’t find the code that clamps the hidden_states to the max value, described in the fix above, in the T5DenseGatedActDense class.

I am using transformers==4.29.2.
I may have found a logical flaw in T5Block, as shown below:

# Apply Feed Forward layer
hidden_states = self.layer[-1](hidden_states)

# clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16:
    clamp_value = torch.where(
        torch.isinf(hidden_states).any(),
        torch.finfo(hidden_states.dtype).max - 1000,
        torch.finfo(hidden_states.dtype).max,
    )
    hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
Apparently, the code that clamps the hidden_states to the dtype’s max value only runs when hidden_states is fp16. That may work fine when the whole model is in fp16, but I am trying to do mixed-precision training using torch.amp, so the hidden_states returned by T5LayerFF are fp32: the last operation in T5LayerFF is the addition (i.e. the residual connection), which can’t autocast to float16 according to https://pytorch.org/docs/stable/amp.html#cuda-ops-that-can-autocast-to-float16.
A debugger screenshot also confirmed this.
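
Here is a quick standalone check of this promotion behavior (a minimal sketch, assuming a CUDA device):

import torch

lin = torch.nn.Linear(8, 8).cuda()
x = torch.randn(2, 8, device="cuda")  # fp32 tensor standing in for the residual stream

with torch.autocast("cuda", dtype=torch.float16):
    y = lin(x)   # matmuls autocast to fp16
    z = x + y    # the residual add follows normal type promotion instead

print(y.dtype)  # torch.float16
print(z.dtype)  # torch.float32, so the fp16-only clamp branch above is skipped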


Three years later, it seems t5-large still doesn’t support fp16. Has anyone found a solution?

You should use bf16 instead of fp16 for T5 models.
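
For example (a sketch assuming an Ampere-or-newer GPU and a recent transformers version), bf16 can be enabled when loading the model or in the Trainer:

import torch
from transformers import T5ForConditionalGeneration, TrainingArguments

# bf16 has the same exponent range as fp32, so the wo projection does not overflow
model = T5ForConditionalGeneration.from_pretrained("t5-large", torch_dtype=torch.bfloat16)

# for training with the Trainer, use bf16 mixed precision instead of fp16
args = TrainingArguments(output_dir="out", bf16=True)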