CUDA out of memory when using Trainer with compute_metrics

For anyone stuck with this problem when using Vision Transformer models whose outputs are a tuple (e.g. class logits plus box predictions), here is the corresponding preprocess_logits_for_metrics function.

def preprocess_logits_for_metrics_fn(logits_tuple, labels):
    # Called on every evaluation batch, right before the Trainer caches the
    # outputs for compute_metrics. Whatever is returned here ends up in
    # EvalPrediction.predictions instead of the raw logits.

    # Unpack the model's output tuple: class logits and box predictions
    cls_logits = logits_tuple[1]
    box_preds = logits_tuple[2]

    # Detach and move to CPU so the evaluation outputs are not accumulated
    # on the GPU (important for memory and multiprocessing)
    cls_logits = cls_logits.detach().cpu()
    box_preds = box_preds.detach().cpu()

    return (cls_logits, box_preds), labels

It is then passed to the Trainer via the preprocess_logits_for_metrics argument:

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["val"],
    processing_class=image_processor,
    data_collator=collate_fn,
    compute_metrics=eval_compute_metrics_fn,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics_fn,
)
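
The eval_compute_metrics_fn referenced above is not shown; a minimal sketch of how it could unpack the preprocessed outputs might look like the following. This assumes cls_logits and box_preds come back stacked as arrays in the shape produced by preprocess_logits_for_metrics_fn; the mean_top_class_confidence metric is only a placeholder so the sketch runs end to end, and you would replace it with whatever detection metric (e.g. mAP) your task actually needs.

import torch

def eval_compute_metrics_fn(eval_pred):
    # eval_pred.predictions holds whatever preprocess_logits_for_metrics_fn
    # returned, gathered over the whole evaluation set (usually as numpy arrays);
    # eval_pred.label_ids holds the labels the Trainer collected separately.
    (cls_logits, box_preds), labels = eval_pred.predictions

    cls_logits = torch.as_tensor(cls_logits)  # e.g. (num_samples, num_queries, num_classes)
    box_preds = torch.as_tensor(box_preds)    # e.g. (num_samples, num_queries, 4)

    # Placeholder metric: average confidence of the top class per query.
    # Swap in a real detection metric here.
    top_scores = cls_logits.softmax(dim=-1).max(dim=-1).values
    return {"mean_top_class_confidence": top_scores.mean().item()}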