Out-of-memory error when using Trainer & output_hidden_states

Hi all,

I have a small dataset, and I am trying to fine-tune bert-base-uncased with AutoModelForSequenceClassification, using the Trainer.

Everything works fine when using:

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(
    pretrained_model_name,
    problem_type="multi_label_classification",
    output_hidden_states=False,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
)

Fine-tuning completes, and GPU RAM usage peaks at 4280 MiB (monitored with nvidia-smi).

But simply changing output_hidden_states to True fails with a CUDA out-of-memory error:

  7%|████████                                                                                                               | 338/5070 [00:50<11:22,  6.93it/s]
***** Running Evaluation *****
  Num examples = 1896
  Batch size = 16
 79%|████████████████████████████████████████████████████████████████████                         | 94/119 [00:15<00:08,  3.08it/s]
Traceback (most recent call last):
  File "/home/petasis/semeval-touche/Human-Value-Detection/petasis/clustering.py", line 93, in <module>
    trainer.train()
  File "/home/petasis/.local/lib/python3.10/site-packages/transformers/trainer.py", line 1527, in train
    return inner_training_loop(
  File "/home/petasis/.local/lib/python3.10/site-packages/transformers/trainer.py", line 1867, in _inner_training_loop
    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
  File "/home/petasis/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2115, in _maybe_log_save_evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
  File "/home/petasis/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2811, in evaluate
    output = eval_loop(
  File "/home/petasis/.local/lib/python3.10/site-packages/transformers/trainer.py", line 3016, in evaluation_loop
    preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
  File "/home/petasis/.local/lib/python3.10/site-packages/transformers/trainer_pt_utils.py", line 113, in nested_concat
    return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
  File "/home/petasis/.local/lib/python3.10/site-packages/transformers/trainer_pt_utils.py", line 113, in <genexpr>
    return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
  File "/home/petasis/.local/lib/python3.10/site-packages/transformers/trainer_pt_utils.py", line 113, in nested_concat
    return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
  File "/home/petasis/.local/lib/python3.10/site-packages/transformers/trainer_pt_utils.py", line 113, in <genexpr>
    return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
  File "/home/petasis/.local/lib/python3.10/site-packages/transformers/trainer_pt_utils.py", line 115, in nested_concat
    return torch_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
  File "/home/petasis/.local/lib/python3.10/site-packages/transformers/trainer_pt_utils.py", line 74, in torch_pad_and_concatenate
    return torch.cat((tensor1, tensor2), dim=0)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 802.00 MiB (GPU 0; 23.64 GiB total capacity; 21.24 GiB already allocated; 526.25 MiB free; 22.22 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
  7%|████████                                                                                                               | 338/5070 [01:06<15:34,  5.06it/s]
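
For reference, some rough arithmetic on how much the hidden states could add per evaluation batch if they are all kept around (assuming fp32 and a maximum sequence length of 512, which is an assumption on my side since I have not shown my tokenizer settings here):

# Rough estimate of the extra memory per eval batch when output_hidden_states=True.
# Assumptions: fp32 (4 bytes), max sequence length 512, bert-base (12 layers + embeddings).
layers = 13          # 12 transformer layers + the embedding output
batch_size = 16
seq_len = 512        # assumed max_length; my real value may be smaller
hidden_size = 768
bytes_per_float = 4

per_batch = layers * batch_size * seq_len * hidden_size * bytes_per_float
print(f"{per_batch / 2**20:.0f} MiB per eval batch")                  # ~312 MiB
print(f"{119 * per_batch / 2**30:.0f} GiB for all 119 eval batches")  # ~36 GiB

So if the evaluation loop really keeps everything on the GPU until the end, 24 GB would not be enough.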

My card has 24 GB of VRAM. It feels like memory is being leaked, or at least needlessly accumulated, somewhere, but I haven't modified anything; I am only using the library classes.
How can I fix this?
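
From the traceback, the evaluation loop seems to keep concatenating every batch's outputs on the GPU (nested_concat in trainer_pt_utils.py), and with output_hidden_states=True those outputs include the hidden states of every layer. One workaround I am considering, but have not tested, is to move the accumulated tensors to the CPU more often with eval_accumulation_steps, and to drop the hidden states before they are cached with preprocess_logits_for_metrics. A rough sketch (keep_only_logits is just a name I made up, and output_dir / the datasets are placeholders for my actual ones):

from transformers import Trainer, TrainingArguments

# Untested sketch: drop the hidden states before the Trainer accumulates
# evaluation outputs, so only the classification logits are kept between batches.
def keep_only_logits(logits, labels):
    # With output_hidden_states=True the eval "logits" may arrive as a tuple
    # (classification logits, hidden states); keep just the first element.
    if isinstance(logits, tuple):
        return logits[0]
    return logits

training_args = TrainingArguments(
    output_dir="out",
    per_device_eval_batch_size=16,
    eval_accumulation_steps=8,  # move accumulated eval tensors to the CPU every 8 steps
)

trainer = Trainer(
    model=model,                    # the model from above
    args=training_args,
    train_dataset=train_dataset,    # my existing datasets
    eval_dataset=eval_dataset,
    preprocess_logits_for_metrics=keep_only_logits,
)

Is something along these lines the right direction, or is there a proper way to keep output_hidden_states=True during evaluation without running out of memory?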

George