For anyone stuck with this problem when using a Vision Transformer object-detection model (DETR-style, i.e. one that returns class logits and predicted boxes), here's the corresponding function:
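```python
def preprocess_logits_for_metrics_fn(logits_tuple, labels):
    # The Trainer passes the model outputs (without the loss) as a tuple.
    # For a DETR-style model evaluated with labels, index 1 holds the
    # class logits and index 2 the predicted boxes.
    cls_logits = logits_tuple[1]
    box_preds = logits_tuple[2]
    # Keep only these two tensors, detached and on CPU, so the other
    # (large) outputs are not accumulated during evaluation
    cls_logits = cls_logits.detach().cpu()
    box_preds = box_preds.detach().cpu()
    # Return only the processed logits; the Trainer hands the labels to
    # compute_metrics separately (as label_ids)
    return cls_logits, box_preds
```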
It is then passed to the `Trainer`:
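```python
from transformers import Trainer

# model, training_args, dataset, image_processor, collate_fn and
# eval_compute_metrics_fn are defined earlier in your script
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["val"],
    processing_class=image_processor,
    data_collator=collate_fn,
    compute_metrics=eval_compute_metrics_fn,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics_fn,
)
```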
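The hard-coded indices 1 and 2 in `preprocess_logits_for_metrics_fn` depend on how the `Trainer` flattens the model output before calling it. As far as I can tell, `Trainer.prediction_step` keeps every output field except `loss` (and any ignored keys), in order, so for a DETR-style output evaluated with labels, index 0 is `loss_dict`, index 1 is `logits` and index 2 is `pred_boxes`. The sketch below mimics that step with a hypothetical helper `trainer_style_logits`, using string placeholders instead of tensors; the key order is an assumption based on DETR's output class, so check it against your model.

```python
from collections import OrderedDict


def trainer_style_logits(outputs, ignore_keys=()):
    """Hypothetical helper: roughly mimics how the Trainer turns a dict-like
    model output into the tuple given to preprocess_logits_for_metrics
    (assumed: all keys except "loss" and ignore_keys are kept, in order)."""
    skip = list(ignore_keys) + ["loss"]
    return tuple(v for k, v in outputs.items() if k not in skip)


# Placeholder strings stand in for tensors; the key order is an assumed
# DETR-style output order.
with_labels = OrderedDict(
    loss="loss",
    loss_dict="loss_dict",  # only present when labels are passed
    logits="logits",
    pred_boxes="pred_boxes",
    last_hidden_state="hidden",
)
without_labels = OrderedDict(
    logits="logits",
    pred_boxes="pred_boxes",
    last_hidden_state="hidden",
)

# With labels, positions 1 and 2 are the class logits and boxes
print(trainer_style_logits(with_labels)[1:3])
# Without labels, loss_dict is missing and the positions shift
print(trainer_style_logits(without_labels)[1:3])
```

The first call prints `('logits', 'pred_boxes')` and the second `('pred_boxes', 'hidden')`, which is why the function above assumes the eval set has labels.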