Evaluation and compute_metrics slowdown

I am trying to extract embeddings from a DistilBertModel and therefore need to perform inference on it with a large dataset.

When I run inference with predict() on a DistilBertForSequenceClassification model, I get either a CUDA out-of-memory error or a significant slowdown.

As suggested in other posts (Similar Forum Post), I used eval_accumulation_steps to avoid the OOM error; however, this still leads to an extreme slowdown. Varying eval_accumulation_steps (5, 500, 1k) and the batch size did not resolve the issue.
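The following is a simplified, hypothetical model (plain Python, not Trainer code) of why changing eval_accumulation_steps alone may not remove the slowdown. It assumes, as in older `Trainer.evaluation_loop` implementations, that each step's predictions are concatenated onto a device-side buffer and that every flush concatenates that buffer onto everything already collected on the CPU, with each concatenation writing out its whole result. The function `copied_elements` and its parameters are illustrative names, not part of transformers.

```python
def copied_elements(num_batches: int, batch_elems: int, flush_every: int) -> int:
    """Count elements written by repeated concatenation (simplified model).

    num_batches: number of evaluation steps
    batch_elems: elements produced per step (batch_size * seq_len * hidden, say)
    flush_every: steps between device -> CPU flushes (like eval_accumulation_steps)
    """
    copied = 0
    device_buf = 0  # elements buffered on the device since the last flush
    cpu_total = 0   # elements already accumulated on the CPU
    for step in range(1, num_batches + 1):
        # concat(device_buf, new_batch) writes the whole (growing) buffer
        device_buf += batch_elems
        copied += device_buf
        if step % flush_every == 0 or step == num_batches:
            # concat(cpu_total, device_buf) rewrites everything collected so far
            cpu_total += device_buf
            copied += cpu_total
            device_buf = 0
    return copied


for flush in (1, 2, 4):
    print(flush, copied_elements(num_batches=4, batch_elems=1, flush_every=flush))
```

In this model, flushing after every step and flushing only once at the end copy the same number of elements for 4 batches, and for any fixed flush interval the total grows roughly quadratically with the number of batches, so doubling the dataset nearly quadruples the copying.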

Another suggested approach was preprocess_logits_for_metrics; however, this also only partially reduced the slowdown. The only thing that really worked was having that function return torch.tensor(0). That return value is meaningless, but it indicates an underlying memory issue in how the returned outputs are stored.

def preprocess_logits_for_metrics_inference(logits, labels):
    """
    Original Trainer may have a memory leak.
    This is a workaround to avoid storing too many tensors that are not needed.
    """

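    # NOTE: when the model returns hidden states, logits[1] is the hidden_states
    # tuple; index 0 is the embedding-layer output, the final layer is logits[1][-1].
    last_hidden_state = logits[1][0]
    output = last_hidden_state
    return output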
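For comparison, here is a minimal sketch of a preprocess_logits_for_metrics variant that keeps only one vector per example instead of the full (batch, seq_len, hidden) tensor. It assumes the model is run with output_hidden_states=True, so that the Trainer passes `logits` as a tuple (classification_logits, hidden_states), and that the tokenizer places [CLS] at position 0. The name `preprocess_logits_for_metrics_cls` is hypothetical.

```python
def preprocess_logits_for_metrics_cls(logits, labels):
    """Return only the final-layer [CLS] vector for each example.

    Assumption: logits == (classification_logits, hidden_states), where
    hidden_states is a tuple with one (batch, seq_len, hidden) tensor per
    layer (index 0 = embedding output, index -1 = last layer).
    `labels` is unused here.
    """
    hidden_states = logits[1]
    last_hidden_state = hidden_states[-1]  # (batch, seq_len, hidden)
    # Position 0 only: the result is (batch, hidden) no matter how long the
    # padded sequences are, so per-step outputs have a fixed width.
    return last_hidden_state[:, 0, :]
```

Because the slice uses plain indexing, the same function works on torch tensors inside the Trainer and on NumPy arrays when checking shapes offline.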

Using prediction_loss_only=True also avoids the problem, which further indicates that the memory issue arises when the outputs are processed during inference/evaluation.
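A back-of-the-envelope estimate of how much the stored predictions can weigh when they are kept (which prediction_loss_only=True skips). It assumes float32 outputs, DistilBERT-base's hidden size of 768, and a hypothetical 100,000 examples padded to 512 tokens; the post does not give the real dataset size or sequence lengths. The helper `stored_bytes` is hypothetical.

```python
def stored_bytes(num_samples: int, hidden_size: int, seq_len: int = 1,
                 bytes_per_value: int = 4) -> int:
    """Bytes needed to keep one float per (sample, position, hidden unit)."""
    return num_samples * seq_len * hidden_size * bytes_per_value


# Hypothetical workload: 100k examples, 512 tokens, hidden size 768, float32.
full_sequence = stored_bytes(100_000, hidden_size=768, seq_len=512)
cls_only = stored_bytes(100_000, hidden_size=768)  # one vector per example

print(f"full hidden states: {full_sequence / 2**30:.1f} GiB")
print(f"[CLS] vectors only: {cls_only / 2**20:.1f} MiB")
```

Under these assumptions, keeping every token's hidden state needs about 146 GiB, while keeping a single vector per example needs about 293 MiB, a factor equal to the sequence length.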

Below is the full training and evaluation code. Why does applying both eval_accumulation_steps and preprocess_logits_for_metrics not resolve the issue?

from transformers import (
    DataCollatorWithPadding,
    DistilBertForSequenceClassification,
    Trainer,
    TrainingArguments,
)

model = DistilBertForSequenceClassification.from_pretrained("johannes-garstenauer/distilbert-heaps-masked",
                                                            num_labels=2,
                                                            id2label=id2label, label2id=label2id)

data_collator = DataCollatorWithPadding(tokenizer=pretrained_tokenizer)

model_name = "bin_clean_seq_class"
print(f"Model name: {model_name}")

training_args = TrainingArguments(
    output_dir=model_name,
    learning_rate=2e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=3,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    eval_accumulation_steps=500,
)

def preprocess_logits_for_metrics(logits, labels):
    """
    Original Trainer may have a memory leak.
    This is a workaround to avoid storing too many tensors that are not needed.
    """

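    # NOTE: when the model returns hidden states, logits[1] is the hidden_states
    # tuple; index 0 is the embedding-layer output, the final layer is logits[1][-1].
    last_hidden_state = logits[1][0]
    output = last_hidden_state
    return output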

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset_split["train"],
    eval_dataset=tokenized_dataset_split["test"],
    tokenizer=pretrained_tokenizer,
    data_collator=data_collator,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics
)

pred = trainer.predict(test_dataset=tokenized_dataset_split["test"])