CUDA suddenly runs out of memory when only about a quarter of GPU memory appears to be in use

I am fine-tuning FiD (Fusion in Decoder: paper, github) for generative question answering on a 13 GB dataset that combines ELI5 and MS MARCO. I am training on a single node with two RTX 3090 GPUs (24 GB each). After some number of steps, training stops with a CUDA out of memory error, but the step at which it fails differs from run to run: sometimes it is step 69000 of 776000, sometimes step 23000 of 776000, and so on. When I monitor the GPUs with watch nvidia-smi, the two GPUs show only about 7 GB and 9 GB in use, and then one of the processes suddenly fails with CUDA out of memory. I do not understand why.
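Since watch nvidia-smi only samples usage every couple of seconds, a short-lived spike could be missed. Below is a minimal sketch of how peak memory could be logged from inside the training process instead. The helper names `format_memory_stats` and `log_peak_memory` are hypothetical and not part of FiD; the `torch.cuda` calls are the standard memory statistics functions. Where to call it in the training loop is an assumption.

```python
def format_memory_stats(step, allocated, peak, reserved):
    """Render byte counts as GiB for one training step."""
    gib = 1024 ** 3
    return (
        f"step {step}: allocated={allocated / gib:.2f} GiB, "
        f"peak={peak / gib:.2f} GiB, reserved={reserved / gib:.2f} GiB"
    )


def log_peak_memory(step, device=None):
    """Print current and peak CUDA memory for this process, then reset the peak."""
    import torch  # imported here so format_memory_stats can be used without torch

    line = format_memory_stats(
        step,
        torch.cuda.memory_allocated(device),      # bytes held by live tensors
        torch.cuda.max_memory_allocated(device),  # peak since the last reset
        torch.cuda.memory_reserved(device),       # bytes held by the caching allocator
    )
    print(line, flush=True)
    # Reset so the next call reports the peak of the following interval only.
    torch.cuda.reset_peak_memory_stats(device)
    return line
```

Calling `log_peak_memory(step)` once per step (hypothetical placement) would show whether a particular batch pushes the peak far above the usual 7 to 9 GB.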

I also call torch.cuda.empty_cache() every 500 steps, but training still fails. This is my training script:

export NGPU=2;
python -m torch.distributed.launch \
        --nproc_per_node=$NGPU \
        --train_data /home/jovyan/final_data/merge_ELI5_MS_MARCO.npz \
        --model_size base \
        --per_gpu_batch_size 1 \
        --n_context 4 \
        --name my_experiment \
        --checkpoint_dir checkpoint \
        --accumulation_steps 32 \
        --use_checkpoint \
        --total_steps 776004 \
        --optim adamw \
        --scheduler linear

This is the error:

RuntimeError: CUDA out of memory. Tried to allocate 502.00 MiB (GPU 1; 23.70 GiB total capacity; 20.55 GiB already allocated; 348.81 MiB free; 22.18 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
WARNING:torch.distributed.elastic.multiprocessing.api:Sending process 1970826 closing signal SIGTERM
WARNING:torch.distributed.elastic.multiprocessing.api:Unable to shutdown process 1970826 via 15, forcefully exitting via 9
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 1 (pid: 1970827) of binary: /home/jovyan/.conda_env/fid/bin/python
Traceback (most recent call last):
  File "/home/jovyan/.conda_env/fid/lib/python3.9/", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/jovyan/.conda_env/fid/lib/python3.9/", line 87, in _run_code
    exec(code, run_globals)
  File "/home/jovyan/.conda_env/fid/lib/python3.9/site-packages/torch/distributed/", line 193, in <module>
  File "/home/jovyan/.conda_env/fid/lib/python3.9/site-packages/torch/distributed/", line 189, in main
  File "/home/jovyan/.conda_env/fid/lib/python3.9/site-packages/torch/distributed/", line 174, in launch
  File "/home/jovyan/.conda_env/fid/lib/python3.9/site-packages/torch/distributed/", line 752, in run
  File "/home/jovyan/.conda_env/fid/lib/python3.9/site-packages/torch/distributed/launcher/", line 131, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/jovyan/.conda_env/fid/lib/python3.9/site-packages/torch/distributed/launcher/", line 245, in launch_agent
    raise ChildFailedError(
============================================================ FAILED
Root Cause (first observed failure):
  time      : 2023-01-07_08:47:07
  host      : cc6244288069
  rank      : 1 (local_rank: 1)
  exitcode  : 1 (pid: 1970827)
  error_file: <N/A>
  traceback : To enable traceback see:
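
To make the numbers in the out-of-memory message easier to compare with what nvidia-smi showed, here is a minimal sketch that parses the message and converts everything to MiB. `parse_oom` is a hypothetical helper written for this post, not a PyTorch function, and the regular expressions assume the message wording shown above.

```python
import re

# Conversion factors to MiB for the two units that appear in the message.
UNIT_TO_MIB = {"MiB": 1.0, "GiB": 1024.0}

PATTERNS = {
    "requested": r"Tried to allocate ([\d.]+) (MiB|GiB)",
    "total": r"([\d.]+) (MiB|GiB) total capacity",
    "allocated": r"([\d.]+) (MiB|GiB) already allocated",
    "free": r"([\d.]+) (MiB|GiB) free",
    "reserved": r"([\d.]+) (MiB|GiB) reserved",
}


def parse_oom(message):
    """Return the sizes reported in a 'CUDA out of memory' message, in MiB."""
    sizes = {}
    for name, pattern in PATTERNS.items():
        match = re.search(pattern, message)
        if match:
            sizes[name] = float(match.group(1)) * UNIT_TO_MIB[match.group(2)]
    return sizes


message = (
    "CUDA out of memory. Tried to allocate 502.00 MiB (GPU 1; 23.70 GiB total "
    "capacity; 20.55 GiB already allocated; 348.81 MiB free; 22.18 GiB reserved "
    "in total by PyTorch)"
)
sizes = parse_oom(message)

# Memory the caching allocator holds but that is not backing live tensors.
reserved_unused = sizes["reserved"] - sizes["allocated"]
print(f"requested:        {sizes['requested']:.2f} MiB")
print(f"allocated:        {sizes['allocated']:.2f} MiB")
print(f"reserved, unused: {reserved_unused:.2f} MiB")
```

At the moment of failure the process on GPU 1 had 20.55 GiB allocated to live tensors, far more than the 7 to 9 GB seen in nvidia-smi, while the gap between reserved and allocated memory was roughly 1.6 GiB.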

Could you please help me figure out what is going on? Thank you in advance.