I'm training a model with Seq2SeqTrainer and computing metrics with evaluate; the setup looks something like this:
import evaluate
from transformers import (
    MarianMTModel,
    MarianTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

# Tokenizer for the same checkpoint as the model
tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")

mt_metrics = evaluate.combine(["bleu", "chrf"])

def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions
    predictions = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    references = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    outputs = mt_metrics.compute(predictions=predictions,
                                 references=references)
    return outputs

model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-de")

training_args = Seq2SeqTrainingArguments(
    output_dir='./',
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    logging_steps=100,
    save_steps=500,
    eval_steps=1,
    max_steps=1_000_000,
    evaluation_strategy="steps",
    predict_with_generate=True,
    report_to=None,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=valid_data.with_format("torch"),
    eval_dataset=test_data.with_format("torch"),
    compute_metrics=compute_metrics,
)

trainer.train()
The EvalPrediction object passed to compute_metrics contains the label_ids and the prediction ids, but it does not contain the input_ids. However, some metrics need the source sentences, and therefore the input_ids; for example, with COMET I would want something like:
mt_metrics = evaluate.combine(["bleu", "chrf", "comet"])

def compute_metrics(pred, input_ids):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions
    sources = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
    predictions = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    references = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    outputs = mt_metrics.compute(predictions=predictions,
                                 references=references,
                                 sources=sources)
    return outputs
Is there some way to include the input_ids in the EvalPrediction object when using Seq2SeqTrainer?
If there isn't, could anyone point me to the docs I would need in order to write my own custom Seq2SeqTrainer and EvalPrediction? Thank you in advance!
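In case it helps to show what I'm after, here is the rough, untested idea I had for a subclass. The class name Seq2SeqTrainerWithSources and the eval_input_ids buffer are just placeholders of mine, and I'm not sure evaluation_loop and prediction_step are even the right hooks to override:

from transformers import Seq2SeqTrainer

class Seq2SeqTrainerWithSources(Seq2SeqTrainer):
    # Rough idea: buffer the encoder input_ids seen during evaluation so that
    # compute_metrics can decode them into COMET's `sources` argument.

    def evaluation_loop(self, *args, **kwargs):
        # Reset the buffer before each evaluation pass; the parent loop calls
        # compute_metrics at the end, after every batch has been seen.
        self.eval_input_ids = []
        return super().evaluation_loop(*args, **kwargs)

    def prediction_step(self, model, inputs, prediction_loss_only,
                        ignore_keys=None, **gen_kwargs):
        # Stash this batch's source ids before delegating to the parent.
        self.eval_input_ids.append(inputs["input_ids"].detach().cpu())
        return super().prediction_step(model, inputs, prediction_loss_only,
                                       ignore_keys=ignore_keys, **gen_kwargs)

def compute_metrics(pred):
    # Reaches back into the (global) trainer for the buffered input_ids;
    # decoding batch by batch because sequence lengths differ across batches.
    sources = [
        text
        for batch in trainer.eval_input_ids
        for text in tokenizer.batch_decode(batch, skip_special_tokens=True)
    ]
    labels_ids = pred.label_ids
    pred_ids = pred.predictions
    predictions = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    references = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    return mt_metrics.compute(predictions=predictions,
                              references=references,
                              sources=sources)

This feels fragile (compute_metrics reaching into the trainer, and I haven't checked how it behaves with multiple GPUs), so a cleaner, supported way would be much appreciated.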