I am trying to extract embeddings from a DistilBertModel and therefore need to perform inference on it with a large dataset. When performing inference using predict() on DistilBertForSequenceClassification, I either get CUDA out of memory or a significant slowdown. As suggested in other posts (Similar Forum Post), I used eval_accumulation_steps to avoid the OOM issue, but this still leads to an extreme slowdown. Varying the eval_accumulation_steps values (5, 500, 1000) and the batch sizes did not resolve the issue.
Another suggested approach was using preprocess_logits_for_metrics; however, this also only partially reduced the slowdown. The only real solution was having that function return torch.tensor(0). This is a nonsensical return value, but it proves that there is some underlying memory issue.
def preprocess_logits_for_metrics_inference(logits, labels):
    """
    Original Trainer may have a memory leak.
    This is a workaround to avoid storing too many tensors that are not needed.
    """
    last_hidden_state = logits[1][0]
    output = last_hidden_state
    return output
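For comparison, the variant mentioned above that actually avoided the slowdown simply discards the outputs. A minimal sketch (the function name is just for illustration):

import torch

def preprocess_logits_for_metrics_dummy(logits, labels):
    # Nonsensical workaround: discard the model outputs and return a single
    # scalar tensor, so almost nothing gets accumulated during prediction.
    return torch.tensor(0)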
Also, using prediction_loss_only=True avoids this error, which indicates that the memory problem occurs while the outputs are being processed during inference/evaluation.
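For reference, this is how I set that flag (a minimal sketch; prediction_loss_only is a regular TrainingArguments option):

training_args = TrainingArguments(
    output_dir="bin_clean_seq_class",
    per_device_eval_batch_size=64,
    prediction_loss_only=True,  # only the loss is gathered, no logits/hidden states
)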
Below you can find the full training and eval code. Why does applying both eval_accumulation_steps and preprocess_logits_for_metrics not resolve the issue?
model = DistilBertForSequenceClassification.from_pretrained(
    "johannes-garstenauer/distilbert-heaps-masked",
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)

from transformers import DataCollatorWithPadding, TrainingArguments, Trainer

data_collator = DataCollatorWithPadding(tokenizer=pretrained_tokenizer)

model_name = "bin_clean_seq_class"
print(f"Model name: {model_name}")

training_args = TrainingArguments(
    output_dir=model_name,
    learning_rate=2e-5,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    num_train_epochs=3,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    eval_accumulation_steps=500,
)

def preprocess_logits_for_metrics(logits, labels):
    """
    Original Trainer may have a memory leak.
    This is a workaround to avoid storing too many tensors that are not needed.
    """
    last_hidden_state = logits[1][0]
    output = last_hidden_state
    return output

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset_split["train"],
    eval_dataset=tokenized_dataset_split["test"],
    tokenizer=pretrained_tokenizer,
    data_collator=data_collator,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
pred = trainer.predict(test_dataset=tokenized_dataset_split["test"])
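For context, I then read the accumulated outputs from the returned PredictionOutput, roughly like this (assuming preprocess_logits_for_metrics returns the hidden states as above, pred.predictions would be the embeddings):

# pred.predictions holds whatever preprocess_logits_for_metrics returned,
# concatenated over the whole test set.
embeddings = pred.predictions
print(embeddings.shape)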