CUDA out of memory when using Trainer with compute_metrics

Sorry for the late reply. Actually, I used logits[0] because the version I was using passed both the logits and the labels inside logits (though I may also have been misinterpreting them). I agree with you: the correct snippet should use logits instead of logits[0] in the argmax.

pred_ids = torch.argmax(logits, dim=-1)
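For context, here is a minimal sketch of how that line can be used to avoid the out-of-memory error from the title. It assumes the snippet lives in a function passed to Trainer's preprocess_logits_for_metrics argument (a real Trainer hook, but the source post does not name it explicitly), which lets you reduce the logits to token ids before Trainer accumulates them across evaluation steps; the tuple check covers the older-version behavior described above:

```python
import torch

def preprocess_logits_for_metrics(logits, labels):
    # Some versions pass a tuple here (e.g. logits together with other
    # outputs); in that case the prediction tensor is logits[0].
    if isinstance(logits, tuple):
        logits = logits[0]
    # Reduce (batch, seq_len, vocab_size) logits to (batch, seq_len) ids
    # right away, so the full logits tensor is never accumulated in memory.
    pred_ids = torch.argmax(logits, dim=-1)
    return pred_ids
```

Returning the argmax-ed ids instead of the raw logits is what keeps memory bounded: compute_metrics then receives the small id tensor rather than the full vocabulary-sized logits.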