Notebooks don’t clear memory very well, so re-running cells can leave old models and tensors allocated and eventually cause OOM errors.
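If you want to retrain in the same session, a minimal sketch for freeing GPU memory first (assuming PyTorch, and that `trainer` and `model` are the objects from the previous run):

import gc
import torch

# assumes `trainer` and `model` exist from the previous run
del trainer, model
gc.collect()
torch.cuda.empty_cache()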
You probably also need to add a logit preprocessing step: without one, the Trainer accumulates the embeddings for every token rather than a single pooled embedding per sample, which consumes hundreds of times more memory.
def preprocess_logits_for_metrics(logits, labels):
    if isinstance(logits, tuple):
        # Depending on the model and config, the output tuple may contain extra
        # tensors, like past_key_values, but the logits always come first
        logits = logits[0]
    # logits should be [batch_size, seq_len, hidden_size]
    # Return only the CLS embedding; for mean pooling you'd need more
    # complicated logic (see the sketch below)
    return logits[:, 0, :]
trainer = Trainer(
    ...,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
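If you'd rather keep mean-pooled embeddings instead of the CLS token, note that `preprocess_logits_for_metrics` only receives the logits and labels, not the attention mask, so this sketch can only average over the full sequence (padding included). It's an illustrative variant, not necessarily the pooling your model was trained with:

def preprocess_logits_for_metrics_mean(logits, labels):
    if isinstance(logits, tuple):
        logits = logits[0]
    # Naive mean over the sequence dimension; a proper masked mean would
    # need the attention mask, which this hook does not receive
    return logits.mean(dim=1)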