Trainer.evaluate() with text generation

Hi everyone, I’m fine-tuning XLNet for generation. For training, I’ve edited the permutation_mask to predict the target sequence one word at a time. I’m evaluating my trained model and am trying to decide between trainer.evaluate() and model.generate(). Running the same input/model with both methods yields different predicted tokens. Is it correct that trainer.evaluate() is not set up for sequential generation? I’ll switch my evaluation code to use model.generate() if that’s the case. Thanks for the help!
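To make the comparison concrete, here is a simplified sketch of the two paths I'm comparing (checkpoint name and prompt are placeholders, not my actual fine-tuned setup):

```python
from transformers import XLNetLMHeadModel, XLNetTokenizer

tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")
model = XLNetLMHeadModel.from_pretrained("xlnet-base-cased")

inputs = tokenizer("The capital of France is", return_tensors="pt")

# Path 1: what Trainer.evaluate() effectively does -- a single
# teacher-forced forward pass over the full sequence at once.
logits = model(**inputs).logits
teacher_forced_ids = logits.argmax(dim=-1)

# Path 2: model.generate() -- sequential decoding, where each step
# conditions on the tokens generated so far.
generated_ids = model.generate(inputs["input_ids"], max_new_tokens=10)

print(tokenizer.batch_decode(teacher_forced_ids))
print(tokenizer.batch_decode(generated_ids))
```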

Hi, I encountered a similar problem when trying to use EncoderDecoderModel for seq2seq tasks. It seems that Trainer does not support text-generation tasks at the moment, as the examples overview at https://huggingface.co/transformers/examples.html shows.

There’s a PR for that; you can try using it.

Is there any update regarding this topic?
I would like to train a VisionEncoderDecoderModel for image captioning and measure the BLEU metric during evaluation. The EvalPrediction object I get in compute_metrics contains only the logits, not the generated texts or tokens (i.e. the result of a beam search). I would assume that computing metrics on the output of generate is not uncommon.
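To illustrate what I mean, this is roughly all compute_metrics can do with what it currently receives (a sketch; decoding and the actual BLEU computation are elided):

```python
import numpy as np

def compute_metrics(eval_pred):
    # EvalPrediction unpacks into (predictions, label_ids); here the
    # predictions are raw logits from a teacher-forced forward pass.
    logits, label_ids = eval_pred
    # The best I can reconstruct is a greedy per-position argmax --
    # not the result of a beam search via generate().
    pred_ids = np.argmax(logits, axis=-1)
    # ... decode pred_ids / label_ids and compute BLEU here ...
    return {}
```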

The PR mentioned in this thread seems to be stale, and there have been quite a few changes to Trainer since it was proposed.

Hi @cgawron, you can take a look at my TrOCR notebooks here: https://github.com/NielsRogge/Transformers-Tutorials/tree/master/TrOCR.

The repository includes several example notebooks on fine-tuning TrOCR (which is an instance of VisionEncoderDecoderModel). One notebook uses the Seq2SeqTrainer, and another uses native PyTorch. In both cases, I illustrate how to compute metrics using generate (in the notebooks I use CER, but you can easily replace it with something like ROUGE or BLEU).
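In a nutshell, the metric computation in those notebooks looks roughly like this (a condensed sketch, not a verbatim copy; processor is assumed to be a TrOCRProcessor defined elsewhere):

```python
from datasets import load_metric

cer_metric = load_metric("cer")

def compute_metrics(pred):
    # With predict_with_generate=True, pred.predictions contains the
    # generated token ids rather than raw logits.
    pred_ids = pred.predictions
    label_ids = pred.label_ids
    # Restore pad tokens that were masked out as -100 in the labels.
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    return {"cer": cer}
```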

Thank you, @nielsr, for providing these examples!

Just for reference for other readers:
Seq2SeqTrainingArguments now includes a predict_with_generate flag for exactly this purpose.
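A minimal example of the flag in use (values are illustrative; the generation_* arguments are available in recent versions):

```python
from transformers import Seq2SeqTrainingArguments

args = Seq2SeqTrainingArguments(
    output_dir="output",
    predict_with_generate=True,  # compute_metrics then receives generated token ids
    generation_max_length=64,    # optional: cap generate() length during evaluation
    generation_num_beams=4,      # optional: use beam search during evaluation
)
```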