Why eval_accumulation_steps takes so much memory

Notebooks don’t clear memory very well, so running training or evaluation multiple times in the same session can result in an OOM error.
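
If that happens, you can usually reclaim most of the memory without restarting the kernel. A minimal sketch, assuming your existing objects are named trainer and model (adjust to whatever names you used):

import gc
import torch

del trainer, model        # hypothetical names for the objects holding GPU tensors
gc.collect()              # drop Python references so CUDA tensors can be freed
torch.cuda.empty_cache()  # release cached GPU memory back to the driver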

You probably also need to add a logit preprocessing step, because otherwise the Trainer accumulates the embedding of every token rather than a single pooled one per sample. That multiplies memory use by roughly a factor of seq_len, i.e. hundreds of times for typical sequence lengths.


def preprocess_logits_for_metrics(logits, labels):
    if isinstance(logits, tuple):
        # Depending on the model and config, the output may contain extra tensors,
        # like past_key_values, but the logits always come first
        logits = logits[0]
    # logits should be [bs, seq_len, hidden_size];
    # if you want mean pooling instead, you'll need more complicated logic (see the sketch below)
    return logits[:, 0, :]  # return the CLS embedding for each sample
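
For mean pooling, note that this hook only receives the logits and the labels, not the attention mask, so a properly masked mean isn't directly possible here. A rough sketch of an unmasked mean (padding tokens get averaged in):

def preprocess_logits_for_metrics(logits, labels):
    if isinstance(logits, tuple):
        logits = logits[0]
    # Unmasked mean over the sequence dimension: [bs, seq_len, hidden_size] -> [bs, hidden_size].
    # A masked mean would need the attention mask, which this hook does not receive.
    return logits.mean(dim=1)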


trainer = Trainer(
    ...,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
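
For context, a fuller setup might look like the sketch below. eval_accumulation_steps and per_device_eval_batch_size are standard TrainingArguments parameters, while model, eval_dataset, and compute_metrics are placeholders for your own objects:

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="out",
    per_device_eval_batch_size=16,
    eval_accumulation_steps=8,  # move accumulated predictions to the CPU every 8 eval steps
)

trainer = Trainer(
    model=model,                      # placeholder: your model
    args=training_args,
    eval_dataset=eval_dataset,        # placeholder: your eval set
    compute_metrics=compute_metrics,  # now receives [bs, hidden_size] instead of [bs, seq_len, hidden_size]
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)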