CUDA OOM on first backward pass after evaluation

I am fine-tuning a google/gemma-2-2b-it model with Trainer on 8 x L4 GPUs, using ZeRO stage 3 partitioning via Accelerate's DeepSpeed integration. Once the first training epoch finishes, evaluation runs and I see the resulting metrics in my terminal. The problem appears when training resumes: during the first backward pass of the second training epoch, I get the following OOM error:

File "/home/leobianco/hallucination_reduction/reward_model.py", line 220, in main
trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
File "/home/leobianco/hallucination_reduction/.venv/lib/python3.11/site-packages/transformers/trainer.py", line 2123, in train
return inner_training_loop(
        ^^^^^^^^^^^^^^^^^^^^
File "/home/leobianco/hallucination_reduction/.venv/lib/python3.11/site-packages/transformers/trainer.py", line 2481, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/leobianco/hallucination_reduction/.venv/lib/python3.11/site-packages/transformers/trainer.py", line 3612, in training_step
self.accelerator.backward(loss, **kwargs)
File "/home/leobianco/hallucination_reduction/.venv/lib/python3.11/site-packages/accelerate/accelerator.py", line 2233, in backward
self.deepspeed_engine_wrapped.backward(loss, **kwargs)
File "/home/leobianco/hallucination_reduction/.venv/lib/python3.11/site-packages/accelerate/utils/deepspeed.py", line 186, in backward
self.engine.backward(loss, **kwargs)
File "/home/leobianco/hallucination_reduction/.venv/lib/python3.11/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
ret_val = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
File "/home/leobianco/hallucination_reduction/.venv/lib/python3.11/site-packages/deepspeed/runtime/engine.py", line 2020, in backward
self.optimizer.backward(loss, retain_graph=retain_graph)
File "/home/leobianco/hallucination_reduction/.venv/lib/python3.11/site-packages/deepspeed/utils/nvtx.py", line 18, in wrapped_fn
ret_val = func(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^
File "/home/leobianco/hallucination_reduction/.venv/lib/python3.11/site-packages/deepspeed/runtime/zero/stage3.py", line 2259, in backward
self.loss_scaler.backward(loss.float(), retain_graph=retain_graph)
File "/home/leobianco/hallucination_reduction/.venv/lib/python3.11/site-packages/deepspeed/runtime/fp16/loss_scaler.py", line 63, in backward
scaled_loss.backward(retain_graph=retain_graph)
File "/home/leobianco/hallucination_reduction/.venv/lib/python3.11/site-packages/torch/_tensor.py", line 581, in backward
torch.autograd.backward(
File "/home/leobianco/hallucination_reduction/.venv/lib/python3.11/site-packages/torch/autograd/__init__.py", line 347, in backward
_engine_run_backward(
File "/home/leobianco/hallucination_reduction/.venv/lib/python3.11/site-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 500.00 MiB. GPU 6 has a total capacity of 21.95 GiB of which 318.12 MiB is free. Including non-PyTorch memory, this process has 21.63 GiB memory in use. Of the allocated memory 12.71 GiB is allocated by PyTorch, and 8.57 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.

Indeed, the reserved but unallocated memory is large, so I tried setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True, but I get Warning: expandable_segments not supported on this platform (function operator()), so I guess this option is not available to me.
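
For reference, this is the kind of thing I tried, whether via the shell environment or, as in this minimal sketch, from Python (the variable has to be set before torch initializes CUDA, otherwise it is ignored):

import os

# PYTORCH_CUDA_ALLOC_CONF is only read when the CUDA caching allocator is
# initialized, so it must be set before torch touches the GPU.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch  # imported only after the allocator config is in place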

To be clear, I have already set small batch sizes and a low number of evaluation accumulation steps:

--per_device_train_batch_size 1 \
--per_device_eval_batch_size 1 \
--gradient_accumulation_steps 1 \
--eval_accumulation_steps 1 \
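
In TrainingArguments terms, the relevant part of my setup looks roughly like the sketch below (not my exact script; eval_strategy="epoch" is inferred from evaluation running at the end of each training epoch, and output_dir is a placeholder):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="reward_model_output",   # placeholder
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=1,
    eval_accumulation_steps=1,
    eval_strategy="epoch",              # assumed from the per-epoch evaluation
    num_train_epochs=3,                 # the default
)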

I am also using preprocess_logits_for_metrics as follows:

import torch


def preprocess_logits_for_metrics(logits, labels):
    """Important for avoiding OOMing, since during evaluation all logits are
    kept on the GPU, and even lowering eval_accumulation_steps did not work.

    answer_idx is the index of "Yes" or "No", which is effectively the last
    token before the -100 mask.

    last_token_idx is the index of the last token that is not the answer,
    typically the ":" at the end of the prompt.

    yes_token_id and no_token_id are defined outside the scope of the function.
    """
    # Index of the answer token: the last position before labels turn into -100.
    answer_idx = torch.argmax((labels == -100).to(dtype=torch.int), dim=1) - 1
    # Index of the token just before the answer (typically the ":").
    last_token_idx = answer_idx - 1
    # Keep only the "Yes"/"No" logits at that position for each example.
    processed_logits = logits[
        torch.arange(len(last_token_idx)),
        last_token_idx
    ][:, [yes_token_id, no_token_id]]

    return processed_logits
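
For context, the two token ids referenced in the docstring come from the tokenizer, and the function itself is passed to Trainer via its preprocess_logits_for_metrics argument. A minimal sketch of the token id part (the exact way the answer is tokenized is an assumption and may differ in my script):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")

# Assumption: the answer is tokenized as a bare "Yes" / "No"; depending on the
# prompt format it could be a leading-space variant instead.
yes_token_id = tokenizer("Yes", add_special_tokens=False).input_ids[0]
no_token_id = tokenizer("No", add_special_tokens=False).input_ids[0]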

Here is my accelerate / deepspeed config:

compute_environment: LOCAL_MACHINE
debug: false
deepspeed_config:
  gradient_accumulation_steps: 1
  offload_optimizer_device: cpu
  offload_param_device: cpu
  zero3_init_flag: false
  zero3_save_16bit_model: false
  zero_stage: 3
distributed_type: DEEPSPEED
downcast_bf16: 'no'
enable_cpu_affinity: false
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 8
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

It is very weird to me that the model could train through the entire first epoch but then OOMs during the second, right after evaluation. I would greatly appreciate any help in clarifying the issue!

EDIT: changing the seed allowed the model to finish training for three epochs (the default). However, memory usage kept growing across the epochs, which is worrying; since I will eventually need to train a larger model for more epochs, I do not think the problem is really solved.
