Here are a few things you can try:
- Make sure your model returns only the logits and no extra tensors (such as hidden states or attentions), since every tensor it returns is accumulated on the GPU during evaluation.
- Use `eval_accumulation_steps` to regularly offload the accumulated predictions from the GPU to the CPU. This is slower, but it avoids this OOM error.
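Both suggestions can be combined in a minimal sketch like the one below. It assumes a classification-style model where `compute_metrics` only needs the predicted class ids. It uses the `Trainer` argument `preprocess_logits_for_metrics`, which is available in recent `transformers` versions, to drop extra outputs and shrink the logits before they are accumulated. The helper `build_trainer`, the `output_dir` value, and the step and batch-size numbers are hypothetical placeholders, not recommendations.

```python
import torch
from transformers import Trainer, TrainingArguments


def preprocess_logits_for_metrics(logits, labels):
    """Reduce each evaluation batch before the Trainer accumulates it.

    Assumption: compute_metrics only needs predicted class ids, not raw logits.
    """
    # If the model returns extra tensors, the Trainer passes them as a tuple;
    # keep only the first element (the logits) and discard the rest.
    if isinstance(logits, tuple):
        logits = logits[0]
    # Storing argmax ids instead of full logits greatly reduces accumulated memory.
    return logits.argmax(dim=-1)


def build_trainer(model, eval_dataset, compute_metrics):
    """Hypothetical helper wiring both OOM mitigations into a Trainer."""
    args = TrainingArguments(
        output_dir="out",             # placeholder path
        per_device_eval_batch_size=8,  # placeholder value
        # Move accumulated predictions to the CPU every 16 evaluation steps
        # instead of keeping all of them on the GPU until the end.
        eval_accumulation_steps=16,
    )
    return Trainer(
        model=model,
        args=args,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
```

With this setup, `compute_metrics` receives the argmax ids as `predictions`, so it should not call `argmax` on them again. A smaller `eval_accumulation_steps` frees GPU memory more often at the cost of more device-to-host copies.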