Evaluation became slower and slower during Trainer.train()

I think I found a way to solve it! :grin:
According to https://discuss.huggingface.co/t/cuda-out-of-memory-when-using-trainer-with-compute-metrics/2941/2, the probable reason is that "When computing metrics inside the Trainer, your predictions are all gathered together on the device (GPU/TPU) and only passed back to the CPU at the end (because that operation can be slow)."
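
Just to see the scale: the raw logits for the whole eval set have shape (num_samples, seq_len, vocab_size), while the argmax ids are only (num_samples, seq_len). A rough back-of-the-envelope sketch (the sizes here are made up, not from my actual run):

    # Illustrative numbers only: 1k eval samples, seq_len 1024, 32k vocab
    num_samples, seq_len, vocab_size = 1000, 1024, 32000
    full_logits_gb = num_samples * seq_len * vocab_size * 4 / 1e9  # fp32 logits: ~131 GB
    pred_ids_gb = num_samples * seq_len * 8 / 1e9                  # int64 ids: ~0.008 GB
    print(full_logits_gb, pred_ids_gb)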

But when computing the metrics we do not need all the logits, just the index of the largest one. So I solved the problem by introducing a preprocess_logits_for_metrics function alongside my compute_metrics:

    import numpy as np

    def compute_metrics_acc(tokenizer):
        def compute_metric(eval_preds):
            preds, targets = eval_preds
            # Replace -100 (the ignored positions) with the pad token id so they can be decoded
            preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
            targets = np.where(targets != -100, targets, tokenizer.pad_token_id)
            preds = tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            targets = tokenizer.batch_decode(targets, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            assert len(preds) == len(targets)
            correct = 0
            for pred, target in zip(preds, targets):
                reference = extract_ans(target)
                best_option = extract_ans(pred)
                if reference == best_option and reference is not False:
                    correct += 1
            return {'accuracy': correct / len(targets)}
        return compute_metric
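
Note that extract_ans is my own helper: it pulls the final answer out of a decoded string and returns False when none is found. A minimal sketch of the idea, assuming answers follow a "The answer is ..." pattern (adapt the regex to your own format):

    import re

    def extract_ans(text):
        # Hypothetical sketch: grab whatever follows "The answer is";
        # return False when the pattern is absent.
        match = re.search(r"The answer is\s*(.+)", text)
        return match.group(1).strip() if match else False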

    import torch

    def preprocess_logits_for_metrics(logits, labels):
        """
        The original Trainer may have a memory leak.
        This is a workaround to avoid storing too many tensors that are not needed.
        """
        # Keep only the predicted token ids; the full vocab-sized logits are dropped
        pred_ids = torch.argmax(logits, dim=-1)
        return pred_ids
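
With this hook the Trainer accumulates only the argmax token ids instead of the full (batch, seq_len, vocab_size) logits, so the preds that arrive in compute_metric are already token ids and can go straight into batch_decode.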

and passing both to the trainer.
Here is my trainer setup:

    trainer = SFTTrainer(
        model=base_model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=peft_config,
        packing=script_args.packing,
        max_seq_length=1024,
        tokenizer=tokenizer,
        args=training_args,
        data_collator=collator,
        compute_metrics=compute_metrics_acc(tokenizer),
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        formatting_func=prepare_sample_text,
    )
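
If memory is still an issue after this, TrainingArguments also has an eval_accumulation_steps option that moves the accumulated tensors to the CPU every N eval steps instead of keeping everything on the GPU until the very end. Something like this (the values are placeholders):

    from transformers import TrainingArguments

    training_args = TrainingArguments(
        output_dir="out",
        per_device_eval_batch_size=8,
        eval_accumulation_steps=10,  # offload accumulated eval tensors to CPU every 10 steps
    )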