The metrics available in evaluate have been extremely helpful, but it has been a little verbose to keep pasting a compute_metrics function around the metric every time we use the Trainer object.
E.g. I think there's quite a lot of people who have been pasting something like this when training a span classification model:
metric = evaluate.load("seqeval")
def compute_metrics(p):
predictions, labels = p
predictions = predictions.argmax(axis=2)
# Remove ignored index (special tokens)
true_predictions = [
[label_list[p] for (p, l) in zip(prediction, label) if l != -100]
for prediction, label in zip(predictions, labels)
]
true_labels = [
[label_list[l] for (p, l) in zip(prediction, label) if l != -100]
for prediction, label in zip(predictions, labels)
]
results = metric.compute(predictions=true_predictions, references=true_labels)
return {
"precision": results["overall_precision"],
"recall": results["overall_recall"],
"f1": results["overall_f1"],
"accuracy": results["overall_accuracy"],
}
# Initialize our Trainer
trainer = Trainer(
    model=model,
    train_dataset=ds_train,
    eval_dataset=ds_eval,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
Would it be a good idea to add a default compute_metrics method to the popular metrics in evaluate?
The ideal usage would be something like:
metric = evaluate.load("seqeval")
trainer = Trainer(
model=model,
train_dataset=ds_train,
eval_dataset=ds_eval,
data_collator=data_collator,
tokenizer=tokenizer,
compute_metrics=metric.compute_metrics,
)
Or maybe something like:
# Allow some renaming of the expected output metric keys.
metric = evaluate.load("seqeval", rename_outputs={"overall_precision": "precision", ...})

trainer = Trainer(
    model=model,
    train_dataset=ds_train,
    eval_dataset=ds_eval,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=metric.compute_metrics,
)
I guess there's a lot of nuance in how users might want to modify compute_metrics, but some task metrics are very common, and removing the need to copy-paste compute_metrics could lower the barrier to using Trainer even further.