Unispeech-sat-base-plus-sv evaluation runs out of VRAM

Hi! I’m currently working on an utterance classification problem. Previously I was using the Wav2Vec2 model (facebook/wav2vec2-base) and everything was fine. Now I have tried this new model with the same code. Training goes fine, but evaluation runs out of CUDA memory:

 43%|████▎     | 55/129 [00:41<01:09,  1.07it/s]
Traceback (most recent call last):
  File "/a2e_workspace/train.py", line 331, in <module>
    trainer.train()
  File "/opt/conda/lib/python3.8/site-packages/transformers/trainer.py", line 1399, in train
    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
  File "/opt/conda/lib/python3.8/site-packages/transformers/trainer.py", line 1521, in _maybe_log_save_evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
  File "/opt/conda/lib/python3.8/site-packages/transformers/trainer.py", line 2158, in evaluate
    output = eval_loop(
  File "/opt/conda/lib/python3.8/site-packages/transformers/trainer.py", line 2341, in evaluation_loop
    preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
  File "/opt/conda/lib/python3.8/site-packages/transformers/trainer_pt_utils.py", line 106, in nested_concat
    return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
  File "/opt/conda/lib/python3.8/site-packages/transformers/trainer_pt_utils.py", line 106, in <genexpr>
    return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
  File "/opt/conda/lib/python3.8/site-packages/transformers/trainer_pt_utils.py", line 106, in nested_concat
    return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
  File "/opt/conda/lib/python3.8/site-packages/transformers/trainer_pt_utils.py", line 106, in <genexpr>
    return type(tensors)(nested_concat(t, n, padding_index=padding_index) for t, n in zip(tensors, new_tensors))
  File "/opt/conda/lib/python3.8/site-packages/transformers/trainer_pt_utils.py", line 108, in nested_concat
    return torch_pad_and_concatenate(tensors, new_tensors, padding_index=padding_index)
  File "/opt/conda/lib/python3.8/site-packages/transformers/trainer_pt_utils.py", line 76, in torch_pad_and_concatenate
    result = tensor1.new_full(new_shape, padding_index)
RuntimeError: CUDA out of memory. Tried to allocate 480.00 MiB (GPU 0; 15.78 GiB total capacity; 13.21 GiB already allocated; 415.75 MiB free; 13.99 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

The GPU has 16 GB VRAM.

Batch size parameters:

  per_device_train_batch_size: 4
  per_device_eval_batch_size: 4
  gradient_accumulation_steps: 4

The code is pretty much default; it’s based on this guide.

The strange thing is that the model is approximately the same size as Wav2Vec2, and memory always runs out specifically during evaluation. So maybe there is a memory leak somewhere in the code.

I’m going to test this on a 32 GB GPU and report the results.

I figured out why this happens. Line 1783 of transformers/models/unispeech_sat/modeling_unispeech_sat.py reads:

output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states

So even if you pass output_hidden_states=None to the forward pass, the output will contain hidden states anyway, because config.use_weighted_layer_sum is True for this model. And if you return the result in an output class like this:

@dataclass
class SpeechClassifierOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None

The hidden_states field will then actually be populated with all the hidden states.
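To make the effect of the quoted line concrete, here is a minimal plain-Python sketch of the flag resolution. The helper name resolve_output_hidden_states is hypothetical and exists only for illustration; it is not part of transformers.

```python
from typing import Optional


def resolve_output_hidden_states(
    use_weighted_layer_sum: bool, output_hidden_states: Optional[bool]
) -> Optional[bool]:
    # Same expression as the quoted line: when the config flag is on,
    # the caller's choice is ignored and hidden states are always requested.
    return True if use_weighted_layer_sum else output_hidden_states


# With the flag on, even an explicit False from the caller is overridden.
for requested in (None, False, True):
    print(requested, "->", resolve_output_hidden_states(True, requested))
```

With the flag off, the caller's value passes through unchanged, which is presumably why the same code behaved differently with a model whose config leaves the flag off.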

The evaluation loop accumulates the outputs for all samples, hidden states included, so you run out of memory. I fixed it by manually setting hidden_states=None in the output.
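A rough toy sketch of why this matters, using plain lists instead of tensors. SketchOutput, forward_sketch and accumulated_values are hypothetical names, the sizes are arbitrary, and the loop only imitates the idea of keeping every non-None output field per batch; it is not the real Trainer code.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class SketchOutput:
    # Stand-in for SpeechClassifierOutput; plain lists replace tensors.
    logits: List[float]
    hidden_states: Optional[Tuple[List[float], ...]] = None


def forward_sketch(keep_hidden_states: bool) -> SketchOutput:
    logits = [0.1, 0.9]
    # Arbitrary toy sizes: 13 "layers" of 8 values each.
    hidden = tuple([0.0] * 8 for _ in range(13))
    return SketchOutput(
        logits=logits,
        hidden_states=hidden if keep_hidden_states else None,
    )


def accumulated_values(num_batches: int, keep_hidden_states: bool) -> int:
    """Count the numbers a loop would hold if it kept every non-None field."""
    total = 0
    for _ in range(num_batches):
        out = forward_sketch(keep_hidden_states)
        total += len(out.logits)
        if out.hidden_states is not None:
            total += sum(len(h) for h in out.hidden_states)
    return total


# 129 matches the number of evaluation steps in the progress bar above.
print(accumulated_values(129, keep_hidden_states=True))
print(accumulated_values(129, keep_hidden_states=False))
```

Returning hidden_states=None means only the logits grow with the number of evaluation batches, which is the whole point of the fix.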

Hopefully this will help someone.