I don’t think that solves the issue; it only moves the memory pressure from the GPU to RAM. The real solution is the preprocess_logits_for_metrics
function (here).
I’ll leave my specific solution here (both functions):
def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions[0]
    # Decode the predicted token ids into text
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    # Replace the -100 used for ignored positions so the labels can be decoded
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)
    rouge_output = rouge.compute(
        predictions=pred_str,
        references=label_str,
        rouge_types=["rouge1", "rouge2", "rougeL", "rougeLsum"],
    )
    return {
        "R1": round(rouge_output["rouge1"], 4),
        "R2": round(rouge_output["rouge2"], 4),
        "RL": round(rouge_output["rougeL"], 4),
        "RLsum": round(rouge_output["rougeLsum"], 4),
    }
def preprocess_logits_for_metrics(logits, labels):
    """
    Original Trainer may have a memory leak.
    This is a workaround to avoid storing too many tensors that are not needed.
    """
    # Reduce each (batch, seq_len, vocab_size) logits tensor to (batch, seq_len)
    # token ids, so only the argmax predictions are accumulated during evaluation
    pred_ids = torch.argmax(logits[0], dim=-1)
    return pred_ids, labels
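Note that logits[0] assumes the model output is a tuple whose first element holds the logits; if your model returns the logits tensor directly, use logits instead. In case it’s useful, here is a minimal sketch of how the two functions plug into the Trainer; model, tokenizer, training_args, train_dataset, and eval_dataset are placeholders for your own setup:

import evaluate
import torch
from transformers import Trainer

rouge = evaluate.load("rouge")  # the metric object used by compute_metrics above

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
metrics = trainer.evaluate()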
BTW, with this approach you may not need eval_accumulation_steps=1
(which slows down evaluation significantly).
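If memory is still tight, eval_accumulation_steps controls how often the accumulated evaluation tensors are moved from the GPU to the CPU; a hypothetical configuration with illustrative values:

from transformers import TrainingArguments

# hypothetical values, shown only to illustrate the option
training_args = TrainingArguments(
    output_dir="out",
    per_device_eval_batch_size=8,
    eval_accumulation_steps=8,  # move accumulated eval tensors to CPU every 8 steps
)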