I'm training a model with Seq2SeqTrainer and computing the evaluation metrics with the evaluate library. The setup looks something like this:
import evaluate
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-de")

mt_metrics = evaluate.combine(
    ["bleu", "chrf"]
)

def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions
    predictions = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    # -100 marks tokens ignored by the loss; swap it back before decoding
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    references = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    outputs = mt_metrics.compute(predictions=predictions,
                                 references=references)
    return outputs
from transformers import (MarianMTModel, Seq2SeqTrainer,
                          Seq2SeqTrainingArguments)

model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-de")

training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    logging_steps=100,
    save_steps=500,
    eval_steps=1,
    max_steps=1_000_000,
    evaluation_strategy="steps",
    predict_with_generate=True,
    report_to="none",
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=valid_data.with_format("torch"),
    eval_dataset=test_data.with_format("torch"),
    compute_metrics=compute_metrics,
)

trainer.train()
The EvalPrediction object exposed to compute_metrics contains the label_ids and the prediction ids, but it does not contain the input_ids. Sometimes a metric needs the input_ids as well, for example COMET, which scores translations against the source sentences. What I would like to write is:
mt_metrics = evaluate.combine(
    ["bleu", "chrf", "comet"]
)

def compute_metrics(pred, input_ids):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions
    sources = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
    predictions = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    references = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    outputs = mt_metrics.compute(predictions=predictions,
                                 references=references,
                                 sources=sources)
    return outputs
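The only workaround I have found so far is to ignore the EvalPrediction object for the sources and close over the eval dataset instead, decoding the source sentences straight from its input_ids column. A minimal sketch, assuming test_data is already tokenized and that the evaluation loop iterates it in order without shuffling:

sources = tokenizer.batch_decode(test_data["input_ids"],
                                 skip_special_tokens=True)

def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions
    predictions = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    references = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    # sources comes from the enclosing scope and relies on dataset order
    return mt_metrics.compute(predictions=predictions,
                              references=references,
                              sources=sources)

This works, but it feels brittle: nothing guarantees that the decoded sources stay aligned with the predictions if the eval dataset ever changes.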
Is there some way to include the input_ids in the EvalPrediction object when using Seq2SeqTrainer?
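I have also seen mentions of an include_inputs_for_metrics flag in more recent transformers releases, which, as far as I can tell, makes the trainer populate an inputs field on EvalPrediction. I have not been able to confirm which version introduced it, so the following is only a sketch of how I would expect it to work:

training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    predict_with_generate=True,
    # assumption: this flag exists in the installed transformers version
    # and fills pred.inputs with the input_ids seen during evaluation
    include_inputs_for_metrics=True,
)

def compute_metrics(pred):
    sources = tokenizer.batch_decode(pred.inputs, skip_special_tokens=True)
    ...  # the rest as in the compute_metrics above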
If there isn’t a built-in way, could anyone point me to the docs I would need to write my own custom Seq2SeqTrainer and EvalPrediction? Thank you in advance!