Bug in Summarization tutorial

Hello, I was reproducing the Summarization tutorial.

The code seems to contain the same problem that is discussed here: Decoding error while using DataCollatorForSeq2Seq · Issue #24433 · huggingface/transformers · GitHub
Forgetting to replace the -100 values in the predictions leads to this error:
OverflowError: out of range integral type conversion attempted
In the compute_metrics function, however, this replacement is done only for the labels, not for the predictions.

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # Note: unlike the labels below, predictions are decoded without replacing -100 first
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}
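To check whether a particular evaluation run is affected, one can look for -100 in the raw arrays before decoding them. Below is a minimal NumPy sketch; has_ignore_index is a hypothetical helper name, and the arrays hold made-up token ids rather than real output from the tutorial.

```python
import numpy as np

IGNORE_INDEX = -100  # padding value that batch_decode cannot handle


def has_ignore_index(token_ids: np.ndarray) -> bool:
    """Return True if any position in token_ids holds -100."""
    return bool(np.any(token_ids == IGNORE_INDEX))


# Made-up arrays: the second row of each was padded with -100
predictions = np.array([[0, 71, 19, 1], [0, 88, 1, IGNORE_INDEX]])
labels = np.array([[71, 19, 1], [88, 1, IGNORE_INDEX]])

print(has_ignore_index(predictions))  # True -> decoding these predictions would fail
```

Running this check on predictions inside compute_metrics (before batch_decode) shows whether the current batch of predictions would trigger the error.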

This is a very tricky bug, and it would probably be better to change something in the library so that people don’t have to do this manually every time, because, as you can see, it leads to a hard-to-debug problem.
It’s hard to debug because the error is thrown only when padding is used, which happens seemingly at random.
The default generation length is 20, which is why the notebook from the tutorial usually runs without the error.
But after increasing the generation length, e.g. to 100 via generation_max_length, it fails much more often:

training_args = Seq2SeqTrainingArguments(
    ...
    generation_max_length=100,
)
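Why a longer generation limit makes the failure more likely can be illustrated with a toy model: when every sequence hits the same length no padding is needed, but when batches stop at different lengths they must be padded to a common width before being stacked. This is a minimal sketch under the assumption, taken from the linked issue, that the padding value is -100; pad_and_stack is a hypothetical helper, not the Trainer’s actual code.

```python
import numpy as np


def pad_and_stack(batches, pad_value=-100):
    """Right-pad 2-D batches of token ids to a common width, then stack them.

    Hypothetical stand-in for gathering predictions from several evaluation
    batches; the pad value -100 is an assumption based on the linked issue.
    """
    width = max(b.shape[1] for b in batches)
    padded = [
        np.pad(b, ((0, 0), (0, width - b.shape[1])), constant_values=pad_value)
        for b in batches
    ]
    return np.concatenate(padded, axis=0)


# Short limit: both batches reach the same length, so no padding is added
same_len = pad_and_stack([np.ones((2, 20), dtype=int), np.ones((2, 20), dtype=int)])
# Longer limit: batches stop at different lengths, so -100 appears
diff_len = pad_and_stack([np.ones((2, 37), dtype=int), np.ones((2, 100), dtype=int)])

print((same_len == -100).any(), (diff_len == -100).any())  # False True
```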

Hey @Hacker1337
Did you find a solution for this? Having the same problem.

Yes, the solution is very simple. As described in the GitHub issue, you simply have to replace -100 with the padding token in the predictions, just as is already done for the labels.

Insert this line before the predictions are used:
predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id)
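The same replacement can be wrapped in a small function and applied to both arrays at the top of compute_metrics. A minimal sketch; replace_ignore_index is a hypothetical helper name, and in the tutorial pad_token_id would come from tokenizer.pad_token_id.

```python
import numpy as np


def replace_ignore_index(token_ids, pad_token_id, ignore_index=-100):
    """Return a copy of token_ids with every ignore_index value set to pad_token_id."""
    token_ids = np.asarray(token_ids)
    return np.where(token_ids != ignore_index, token_ids, pad_token_id)


# Inside compute_metrics, before any call to batch_decode:
#   predictions = replace_ignore_index(predictions, tokenizer.pad_token_id)
#   labels = replace_ignore_index(labels, tokenizer.pad_token_id)

preds = np.array([[0, 71, 1, -100], [0, 88, 19, 1]])
print(replace_ignore_index(preds, pad_token_id=0))
```

Doing the replacement at the very start also fixes the gen_len metric: np.count_nonzero(pred != tokenizer.pad_token_id) would otherwise count the -100 positions as generated tokens.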