Gradient overflow when fine-tuning T5 on the CNN/DM dataset

I am trying to fine-tune T5 on the CNN/DM dataset for the summarization task.

I downloaded the data following the README in examples/seq2seq:

wget https://s3.amazonaws.com/datasets.huggingface.co/summarization/cnn_dm.tgz
tar -xzvf cnn_dm.tgz

I also successfully fine-tuned sshleifer/distilbart-cnn-12-6 on this dataset. But when I try the same with t5-base, I get the following error:

Epoch 1:   0%|                                                                                                                              | 2/37154 [00:07<40:46:19,  3.95s/it, loss=nan, v_num=13]Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 8192.0
Epoch 1:   0%|                                                                                                                              | 3/37154 [00:08<27:57:13,  2.71s/it, loss=nan, v_num=13]Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 4096.0
Epoch 1:   0%|                                                                                                                              | 4/37154 [00:08<21:32:17,  2.09s/it, loss=nan, v_num=13]Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 2048.0
Epoch 1:   0%|                                                                                                                              | 5/37154 [00:08<17:41:05,  1.71s/it, loss=nan, v_num=13]Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 1024.0
Epoch 1:   0%|                                                                                                                              | 6/37154 [00:08<15:15:58,  1.48s/it, loss=nan, v_num=13]Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 512.0
Epoch 1:   0%|                                                                                                                              | 7/37154 [00:09<13:24:44,  1.30s/it, loss=nan, v_num=13]Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 256.0
Epoch 1:   0%|                                                                                                                              | 8/37154 [00:09<12:01:25,  1.17s/it, loss=nan, v_num=13]Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 128.0
Epoch 1:   0%|                                                                                                                              | 9/37154 [00:09<10:56:27,  1.06s/it, loss=nan, v_num=13]Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 64.0
Epoch 1:   0%|                                                                                                                             | 10/37154 [00:09<10:04:29,  1.02it/s, loss=nan, v_num=13]Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 32.0
<...>
Gradient overflow.  Skipping step, loss scaler 0 reducing loss scale to 0.0
Epoch 1:   3%|███▋                                                                                                                        | 1091/37154 [04:13<2:19:28,  4.31it/s, loss=nan, v_num=13]Traceback (most recent call last):
  File "finetune.py", line 409, in <module>
    main(args)
  File "finetune.py", line 383, in main
    logger=logger,
  File "/data/User/v3/bart/lightning_base.py", line 303, in generic_train
    trainer.fit(model)
  File "/data/User/v3/bart/venv/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1003, in fit
    results = self.single_gpu_train(model)
  File "/data/User/v3/bart/venv/lib/python3.7/site-packages/pytorch_lightning/trainer/distrib_parts.py", line 186, in single_gpu_train
    results = self.run_pretrain_routine(model)
  File "/data/User/v3/bart/venv/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1213, in run_pretrain_routine
    self.train()
  File "/data/User/v3/bart/venv/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 370, in train
    self.run_training_epoch()
  File "/data/User/v3/bart/venv/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 452, in run_training_epoch
    batch_output = self.run_training_batch(batch, batch_idx)
  File "/data/User/v3/bart/venv/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 632, in run_training_batch
    self.hiddens
  File "/data/User/v3/bart/venv/lib/python3.7/site-packages/pytorch_lightning/trainer/training_loop.py", line 822, in optimizer_closure
    error = context.__exit__(a, b, c)
  File "/root/anaconda3/lib/python3.7/contextlib.py", line 119, in __exit__
    next(self.gen)
  File "/data/User/v3/bart/venv/lib/python3.7/site-packages/apex/amp/handle.py", line 123, in scale_loss
    optimizer._post_amp_backward(loss_scaler)
  File "/data/User/v3/bart/venv/lib/python3.7/site-packages/apex/amp/_process_optimizer.py", line 190, in post_backward_with_master_weights
    models_are_masters=False)
  File "/data/User/v3/bart/venv/lib/python3.7/site-packages/apex/amp/scaler.py", line 117, in unscale
    1./scale)
ZeroDivisionError: float division by zero

My fine-tuning command is modified from examples/seq2seq:

./finetune.sh \
    --data_dir $DATA_DIR \
    --train_batch_size=8 \
    --eval_batch_size=8 \
    --output_dir=$OUTPUT_DIR \
    --num_train_epochs 5 \
    --model_name_or_path t5-base

Can anyone provide some suggestions? Thank you!

It looks like you are using fp16. There are currently a few known bugs with fp16 for T5 that have not been fixed yet, so try turning off fp16. (On every overflow the loss scaler halves the loss scale; once it reaches 0.0, the unscale step computes 1./scale, which is the ZeroDivisionError in your traceback.) Pinging @sshleifer for more info.
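
For example, a minimal sketch of rerunning your command with fp16 disabled, assuming the --fp16 flag is being added by the wrapper script (check your finetune.sh and drop it there if so):

# Remove any --fp16 flag from finetune.sh (or your own wrapper) first,
# then rerun the same command; training then stays in fp32 and the apex
# loss scaler is never invoked.
./finetune.sh \
    --data_dir $DATA_DIR \
    --train_batch_size=8 \
    --eval_batch_size=8 \
    --output_dir=$OUTPUT_DIR \
    --num_train_epochs 5 \
    --model_name_or_path t5-base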


Yes, this is a known issue.


Thank you!

Hi, is there any plan to work on this? It’s my only hope for being able to fine-tune t5-large!

Hi @melody-ju,
T5 fine-tuning works well without fp16. If you want to fine-tune t5-large but are running into memory issues, you can freeze the token embeddings using the --freeze_embeds argument of the finetune.py script.
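
For example, a sketch of a t5-large run with frozen embeddings, reusing the arguments from the original post; the batch sizes here are placeholders and may need to be lowered further for t5-large:

# Fine-tune t5-large with the shared token embeddings frozen to save
# memory; fp16 stays off. Reduce the batch sizes if you still run out
# of GPU memory.
./finetune.sh \
    --data_dir $DATA_DIR \
    --output_dir=$OUTPUT_DIR \
    --model_name_or_path t5-large \
    --train_batch_size=4 \
    --eval_batch_size=4 \
    --num_train_epochs 5 \
    --freeze_embeds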