Training a T5-based seq2seq model suddenly reaches a loss of `nan` and starts predicting only `<pad>`

I’m trying to train a T5-based LM-head model (mrm8488/t5-base-finetuned-wikiSQL) on my custom data to turn text into SQL (based roughly on the SPIDER dataset).

The current training loop I have is something like this:

from transformers import AdamW, get_linear_schedule_with_warmup

parameters = model.parameters()
optimizer = AdamW(parameters, lr=1e-5)  # AdamW imported from `transformers`
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=5,
    num_training_steps=len(data) * nr_epochs,
)

for epoch in range(nr_epochs):
    for batch in data_loader:
        optimizer.zero_grad()
        predictions = model(**batch)
        loss = predictions[0]  # the loss comes first when `labels` are passed
        loss.backward()
        optimizer.step()
        scheduler.step()

Note: this is simplified; I don’t show early stopping, data source creation, data loader creation, some custom scheduling logic, etc. But none of that should be relevant.

Pretty standard: the batch dictionary contains input_ids, attention_mask, labels, and decoder_attention_mask. I get the input_ids and attention_mask from tokenizing my input text, and I get the labels and decoder_attention_mask from tokenizing my target text (with the same tokenizer).
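
Roughly, the batches are built like this (a simplified sketch; the variable names and tokenizer settings here are illustrative, not my exact code):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")

# Tokenize source and target text with the same tokenizer
input_enc = tokenizer(input_texts, padding=True, truncation=True, return_tensors="pt")
target_enc = tokenizer(target_texts, padding=True, truncation=True, return_tensors="pt")

batch = {
    "input_ids": input_enc["input_ids"],
    "attention_mask": input_enc["attention_mask"],
    "labels": target_enc["input_ids"],
    "decoder_attention_mask": target_enc["attention_mask"],
}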

I also tried passing decoder_input_ids (using the same values I used for labels), but it results in a CUDA error (when using a GPU) or a BLAS error (when using a CPU). I tried deep-copying the tensor in case the issue was both this and labels pointing to the same object, but nothing changes.
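
Roughly, that attempt looked like this (simplified; the deep copy is the workaround mentioned above):

import copy

# Same values as the labels, deep-copied so they aren't the same tensor object
decoder_input_ids = copy.deepcopy(batch["labels"])

predictions = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    decoder_input_ids=decoder_input_ids,
    decoder_attention_mask=batch["decoder_attention_mask"],
    labels=batch["labels"],
)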

My main question here is:

Why would this result in the yielded loss suddenly becoming nan and the model, if .backward() is called on that loss, suddenly starting to predict everything as <pad>?

Is it just that <pad> is what the tokenizer decodes when the model predicts “gibberish” (i.e. nan, inf, or a very high or low number that isn’t associated with any char/sequence by the tokenizer)?

Furthermore, losses usually seem to become nan after they start getting higher and higher, but in this case the model seems to be improving until, at one point, a nan drops out of nowhere.

My other questions, to hopefully help address this, are:

  • Is the decoder_attention_mask actually the output attention mask? The model seems to perform much better when I add it; I get it from tokenizing the target text (and it seems to line up with the padding therein) … but my impression was that the “decoder” here was the generator of embeddings and that seq2seq models have an additional LM head on top. Am I just getting my terminology wrong? Is the argument just named poorly?
  • Is there any point to passing decoder_input_ids? Should these just be equivalent to the labels (given that, see above, the “decoder” here seems to refer to the LM head)? Should I consider passing them instead of labels? Why would I get CUDA/BLAS-related crashes when I do pass them?
  • My current approach is to just “ignore” a loss of nan, i.e. clear the gradients, skip backprop, and keep moving (sketched just below this list). Is there a better alternative? Is the loss going to nan unexpected, and maybe a sign that I should look for and remove a “faulty” datapoint from the batch?
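
For reference, the nan-skipping guard I mean looks roughly like this (a sketch of the skip-and-continue approach, not something I’m confident is correct):

import torch

for batch in data_loader:
    optimizer.zero_grad()
    loss = model(**batch)[0]

    # If the loss is nan/inf, skip this batch entirely: the gradients were
    # already cleared above, so no update from it is ever applied.
    if not torch.isfinite(loss):
        continue

    loss.backward()
    optimizer.step()
    scheduler.step()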

I get that this is not an ideal way to be training, but I couldn’t get the Seq2Seq trainer working (I made a question about that here: Extremely confusing or non-existent documentation about the Seq2Seq trainer).

I also cross-posted this on Stack Overflow, in case anyone is helped by that: python - How to avoid huggingface t5-based seq to seq suddenly reaching a loss of `nan` and start predicting only `<pad>`? - Stack Overflow

Hi @George3d6,
did you manage to overcome your problem?

I’m currently working on the same fine-tuning task (t5-base with SPIDER + custom data) and am surprised by the bad performance. After a few epochs of training, the training loss drops below 0.5 (validation loss a little more slowly), but the exact-match accuracy is very low. The model then starts overfitting. I just wanted to ask you for some insights - how did your model perform?

From the Stack Overflow post (this solved the issue for me on flan-t5-xxl):

I had the same problem, but instead of using fp16=True, I used fp16_full_eval=True. This worked for me, I hope it helps!
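
For anyone wondering where that flag lives, it’s a TrainingArguments / Seq2SeqTrainingArguments option; a minimal sketch (the output_dir and other values are placeholders, not the settings from the answer):

from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="t5-text-to-sql",  # placeholder
    fp16=False,             # keep training in full precision
    fp16_full_eval=True,    # run evaluation in half precision
)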