Hi, I am fine-tuning a classification model and would like to log accuracy, precision, recall and F1 using the Trainer API.
When I use metric = load_metric("glue", "mrpc"), it logs accuracy and F1, but when I use metric = load_metric("precision", "recall", "f1"), it only logs the first metric.
Is this by design? Do I need to write a custom script if I want to log all of these metrics per epoch/step with the Trainer API?
You need to load each of those metrics separately; I don’t think the loader accepts a list.
Can you please elaborate on your first statement? I tried these two variations. The first one returns only recall, and the second one throws an error saying a tuple is not accepted.
def compute_metrics(eval_pred):
    # metric = load_metric("glue", "mrpc")
    metric = load_metric("precision")
    metric = load_metric("recall")
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

def compute_metrics(eval_pred):
    # metric = load_metric("glue", "mrpc")
    metric1 = load_metric("precision")
    metric2 = load_metric("recall")
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric1.compute(predictions=predictions, references=labels), metric2.compute(predictions=predictions, references=labels)
You need to return a dictionary with the metrics you want (please check out the course on the Trainer and the Trainer video), so in this case:
def compute_metrics(eval_pred):
    metric1 = load_metric("precision")
    metric2 = load_metric("recall")
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    precision = metric1.compute(predictions=predictions, references=labels)["precision"]
    recall = metric2.compute(predictions=predictions, references=labels)["recall"]
    return {"precision": precision, "recall": recall}
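The same dictionary pattern extends to all four metrics from the original question. Below is a minimal sketch, not the snippet above: to keep it self-contained it computes the values by hand with NumPy instead of calling load_metric, and it assumes binary labels with 1 as the positive class. The zero-division fallback of 0.0 is also an assumption of this sketch.

```python
import numpy as np


def compute_metrics(eval_pred):
    """Return accuracy, precision, recall and F1 in a single dictionary.

    Assumes binary classification with positive class 1 (an assumption of
    this sketch, not something stated in the thread).
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    labels = np.asarray(labels)

    # Confusion counts for the positive class.
    tp = int(np.sum((predictions == 1) & (labels == 1)))
    fp = int(np.sum((predictions == 1) & (labels == 0)))
    fn = int(np.sum((predictions == 0) & (labels == 1)))

    accuracy = float(np.mean(predictions == labels))
    # Fall back to 0.0 when a denominator is zero (assumed convention).
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = (
        2 * precision * recall / (precision + recall)
        if (precision + recall)
        else 0.0
    )

    # One flat dictionary with one key per metric to log.
    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }
```

The important part is the return value: one flat dictionary with a key per metric, rather than a tuple of per-metric results.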
Awesome, I completed the course recently and missed that it needs to return a dictionary. Thanks for the help.
Hi @sgugger, defining a custom compute_metrics is fine, but how do you call add_batch on each mini-batch within a single epoch when you want to compute and log multiple metrics at the end of the epoch (such as precision and recall in the case above)? Could you please give an example if possible?