Compute metrics causes OOM

When fine-tuning FLAN-T5-small for translation, adding compute_metrics to the trainer causes an out-of-memory (OOM) error on the CPU.

Here is the function I used:

def compute_metrics(eval_pred):
        preds, labels = eval_pred
        if isinstance(preds, tuple):
            preds = preds[0]

        decoded_preds = tokenizer.decode(preds, skip_special_tokens=True)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Some simple post-processing
        decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)

        result = bleu.compute(preds=decoded_preds, references=decoded_labels)
        result = {"bleu": result["score"]}

My Trainer:

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

trainer = Seq2SeqTrainer(
    model,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    # BUG: enabling compute_metrics causes the OOM
    # compute_metrics=compute_metrics,
    args=Seq2SeqTrainingArguments(
        # Output directory (repo name)
        output_dir=output_dir,
        run_name="t5-small-Ne-Ro-translit-test",
        predict_with_generate=True,
        # Batch size
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        eval_accumulation_steps=2,
        gradient_accumulation_steps=4,
        gradient_checkpointing=True,
        # Checkpointing
        save_strategy="steps",
        evaluation_strategy="epoch",
        save_steps=2000,
        logging_steps=100,
        # Hyperparameters
        learning_rate=1e-3,
        max_steps=1000,
        # num_train_epochs=1,
        # weight_decay=0.01,
        fp16=False,
        # Logging
        report_to="wandb",
        # Pushing the model to the Hub
        push_to_hub=False,
        hub_private_repo=False,
    ),
)
# use_cache=True is incompatible with gradient checkpointing during training
model.config.use_cache = False
trainer.train()

Solved the issue: compute_metrics did not return anything (the function has no return statement, so the metrics dict was never returned), and that was causing the OOM.
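For reference, here is a sketch of a compute_metrics that returns its result, with two other fixes to the function above: tokenizer.batch_decode instead of tokenizer.decode for the 2-D array of generated ids, and replacing -100 padding in the predictions as well as the labels (the pattern used in the Hugging Face translation example). The factory name make_compute_metrics is hypothetical; the metric is assumed to be evaluate.load("sacrebleu"), which expects predictions= and references= and returns a dict with a "score" key; and plain whitespace stripping stands in for the postprocess_text helper, which isn't shown above.

```python
import numpy as np


def make_compute_metrics(tokenizer, metric):
    """Build a compute_metrics callback for Seq2SeqTrainer (hypothetical helper).

    `tokenizer` is the model's tokenizer; `metric` is assumed to be
    evaluate.load("sacrebleu"), whose result dict has a "score" key.
    """

    def compute_metrics(eval_pred):
        preds, labels = eval_pred
        if isinstance(preds, tuple):
            preds = preds[0]

        # Padded positions hold -100, which the tokenizer cannot decode;
        # swap them for the pad token in both predictions and labels.
        preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

        # batch_decode turns each row of token ids into one string.
        decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Strip whitespace; sacrebleu expects a list of references per prediction.
        decoded_preds = [pred.strip() for pred in decoded_preds]
        decoded_labels = [[label.strip()] for label in decoded_labels]

        result = metric.compute(predictions=decoded_preds, references=decoded_labels)
        # The Trainer expects a dict of metric name -> value, so return it.
        return {"bleu": result["score"]}

    return compute_metrics
```

To use it, pass compute_metrics=make_compute_metrics(tokenizer, bleu) to Seq2SeqTrainer, with bleu = evaluate.load("sacrebleu"). Passing the tokenizer and metric in explicitly avoids relying on globals defined elsewhere in the script.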
