What does "predict_with_generate=True" actually do?

Hello, I am currently trying to fine-tune T5 for a summarization task using PyTorch/XLA, and I want to know the purpose of predict_with_generate. I read the documentation and know it's supposed to be used with ROUGE/BLEU, but I am confused about what it actually does. If I set predict_with_generate=True, will the output be decoded on its own and the metric calculated directly if I pass it as shown below?

from datasets import load_metric

rouge_metric = load_metric("rouge")
trainer = Seq2SeqTrainer(

Or do I have to wrap ROUGE inside another function like the one below, and then pass that?

import numpy as np

def compute_metrics(eval_pred):
    # With predict_with_generate=True, predictions are generated token ids.
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # Use the rouge_metric object loaded above.
    result = rouge_metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    return {k: round(v, 4) for k, v in result.items()}
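
To see why the -100 replacement in the function above matters, here is a minimal pure-Python sketch. It does not use the real tokenizer: the vocabulary, the token ids, and the helpers replace_ignore_index and toy_batch_decode are hypothetical stand-ins, used only to show that -100 has no vocabulary entry and must be swapped for the pad id before decoding.

```python
# Toy stand-in for a tokenizer vocabulary; ids and tokens are hypothetical.
PAD_TOKEN_ID = 0
IGNORE_INDEX = -100  # label value that marks positions to skip in the loss
TOY_VOCAB = {0: "<pad>", 1: "</s>", 5: "the", 6: "cat", 7: "sat"}
SPECIAL_IDS = {0, 1}


def replace_ignore_index(labels, pad_token_id=PAD_TOKEN_ID):
    """Swap -100 for the pad id so every id can be looked up when decoding."""
    return [
        [pad_token_id if token == IGNORE_INDEX else token for token in row]
        for row in labels
    ]


def toy_batch_decode(batch, skip_special_tokens=True):
    """Minimal imitation of batch decoding over the toy vocabulary."""
    decoded = []
    for row in batch:
        tokens = [
            TOY_VOCAB[token]  # a KeyError here is what -100 would trigger
            for token in row
            if not (skip_special_tokens and token in SPECIAL_IDS)
        ]
        decoded.append(" ".join(tokens))
    return decoded


labels = [[5, 6, 7, 1, -100, -100], [5, 6, 1, -100, -100, -100]]
cleaned = replace_ignore_index(labels)
decoded_labels = toy_batch_decode(cleaned)
print(decoded_labels)
```

Once -100 becomes the pad id, skip_special_tokens drops it along with the other special tokens, so the decoded references contain only real text.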

PS: When I use the wrapper method, Colab/Kaggle crashes because RAM usage exceeds 100%, even when I use a TPU.

Look at the example notebook or the example script for summarization. You will see you have to pass along the latter.


Hi @sgugger, I understood the purpose of predict_with_generate from the example script. Normally, the forward pass of the model returns the loss and logits, but ROUGE/BLEU need tokens, which is where generate() comes into the picture, and predict_with_generate handles that. But I still cannot work out why my RAM usage exceeds 100%. I tried caching the datasets, defining them globally, and using xmp.MpSerialExecutor(), but RAM usage keeps increasing with every epoch. I am not able to fine-tune t5-small for more than 4 epochs if I pass compute_metrics.

@sgugger Is it because of the generated tokens at every validation epoch?

Edit: link to kernel
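
On the question of whether the generated tokens could fill the RAM, a rough back-of-the-envelope estimate may help. The sketch below assumes the predictions for the whole validation set are held in host memory before compute_metrics is called. The dataset size and target length are hypothetical, and 32,128 is taken to be the t5-small vocabulary size. It only compares the size of one evaluation pass when logits are kept versus generated token ids; it does not explain growth from one epoch to the next.

```python
def eval_buffer_bytes(num_examples, seq_len, last_dim, bytes_per_value):
    """Size of a dense (num_examples, seq_len, last_dim) array in bytes."""
    return num_examples * seq_len * last_dim * bytes_per_value


NUM_EVAL_EXAMPLES = 10_000  # hypothetical validation set size
MAX_TARGET_LEN = 128        # hypothetical target/generation length
VOCAB_SIZE = 32_128         # assumed t5-small vocabulary size

# Without generation: one float32 score per vocabulary entry per position.
logits_bytes = eval_buffer_bytes(NUM_EVAL_EXAMPLES, MAX_TARGET_LEN, VOCAB_SIZE, 4)

# With generation: one int64 token id per position.
token_bytes = eval_buffer_bytes(NUM_EVAL_EXAMPLES, MAX_TARGET_LEN, 1, 8)

GIB = 1024 ** 3
MIB = 1024 ** 2
print(f"logits:    {logits_bytes / GIB:.1f} GiB")
print(f"token ids: {token_bytes / MIB:.1f} MiB")
print(f"ratio:     {logits_bytes // token_bytes}x")
```

Under these assumptions the generated token ids are small compared with full logits, so if memory still climbs every epoch, it may be worth checking whether something else (for example, objects kept alive between evaluation runs) is accumulating.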