CUDA out of memory only when saving state

When the code reaches accelerator.save_state(save_dir), I keep getting CUDA out of memory. I’m running on 16 A100 GPUs with training_batch = 5 and gradient_accumulation_steps=16, using essentially the same code as the codeparrot example.

The part that breaks:

import os
import shutil

# (inside the training loop)
if step % save_checkpoint_steps == 0:
    logger.info('Evaluating and saving model checkpoint')
    eval_loss, perplexity = evaluate(
        model, accelerator, eval_dataloader, valid_batch_size, max_eval_steps, samples_per_step
    )
    log_metric(accelerator, step, {
        'loss/eval': eval_loss, 'perplexity': perplexity})
    accelerator.wait_for_everyone()
    # separate variable: reassigning save_dir would nest every new checkpoint inside the previous one
    checkpoint_dir = os.path.join(save_dir, f'step_{step}')
    accelerator.save_state(checkpoint_dir)  # <=========== this line right here

    # keep only the last 5 checkpoints: delete the one saved 5 checkpoints ago
    old_checkpoint_dir = os.path.join(save_dir, f'step_{step - 5 * save_checkpoint_steps}')
    if accelerator.is_main_process and os.path.exists(old_checkpoint_dir):
        shutil.rmtree(old_checkpoint_dir)  # os.rmdir only removes empty directories

    if accelerator.is_main_process:
        hf_repo.push_to_hub(commit_message=f'Step {step}')
    model.train()
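The comment in the snippet says only the last 5 checkpoints should be kept. A more robust way to do that is a standalone helper that scans the parent directory, so it also works after resuming or changing the save interval. This is a minimal sketch, assuming checkpoints are stored as `step_<N>` subdirectories of one parent directory (the naming used by `os.path.join(save_dir, f'step_{step}')`); `rotate_checkpoints` is a hypothetical helper, not part of Accelerate. It uses `typing.List` so it also runs on Python 3.7, which the traceback shows is in use.

```python
import os
import re
import shutil
from typing import List

_STEP_DIR = re.compile(r"^step_(\d+)$")


def rotate_checkpoints(save_dir: str, keep_last: int = 5) -> List[str]:
    """Delete all but the newest `keep_last` step_<N> subdirectories of save_dir.

    Returns the names of the removed directories, oldest first.
    """
    checkpoints = []
    for name in os.listdir(save_dir):
        match = _STEP_DIR.match(name)
        if match and os.path.isdir(os.path.join(save_dir, name)):
            checkpoints.append((int(match.group(1)), name))
    # Sort numerically so that step_1000 comes after step_500.
    checkpoints.sort()
    to_remove = checkpoints[:-keep_last] if keep_last > 0 else checkpoints
    removed = []
    for _, name in to_remove:
        # shutil.rmtree also deletes non-empty directories, unlike os.rmdir.
        shutil.rmtree(os.path.join(save_dir, name))
        removed.append(name)
    return removed
```

In the training loop it would run only on the main process once the current save has finished, e.g. `accelerator.wait_for_everyone()` followed by `if accelerator.is_main_process: rotate_checkpoints(save_dir, keep_last=5)`.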

The output log:

04/28/2022 15:40:16 - INFO - accelerate.accelerator - Saving current state to /home/elonsalfati/.metiss/models/metisstg/model/step_500
04/28/2022 15:40:16 - INFO - accelerate.accelerator - Saving current state to /home/elonsalfati/.metiss/models/metisstg/model/step_500
Traceback (most recent call last):
  File "model-train.py", line 475, in <module>
    train(user_args)
  File "model-train.py", line 453, in train
    hf_repo, eval_dataloader, user_args.test_batch_size, user_args.max_test_steps, repo_dir
  File "model-train.py", line 371, in train_model
    accelerator.save_state(save_dir)
  File "/opt/conda/lib/python3.7/site-packages/accelerate/accelerator.py", line 694, in save_state
    weights = [self.get_state_dict(m) for m in self._models]
  File "/opt/conda/lib/python3.7/site-packages/accelerate/accelerator.py", line 694, in <listcomp>
    weights = [self.get_state_dict(m) for m in self._models]
  File "/opt/conda/lib/python3.7/site-packages/accelerate/accelerator.py", line 787, in get_state_dict
    state_dict[k] = state_dict[k].float()
RuntimeError: CUDA out of memory. Tried to allocate 256.00 MiB (GPU 2; 39.59 GiB total capacity; 36.35 GiB already allocated; 38.19 MiB free; 36.54 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
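Judging from the traceback, the failing line is `state_dict[k] = state_dict[k].float()` inside `get_state_dict`, i.e. the model weights are upcast to fp32 on the GPU they live on, and the error is reported on GPU 2 rather than only on the first device. ZeRO stage 2 partitions optimizer states and gradients but keeps the parameters replicated, so each GPU holds a full fp16 copy of the model, and the upcast needs roughly 4 extra bytes per parameter on top of it. The back-of-the-envelope sketch below illustrates that arithmetic; the 1.5e9 parameter count is a hypothetical value for illustration only, not the size of the model in this post.

```python
MIB = 2 ** 20
GIB = 2 ** 30


def upcast_extra_bytes(num_fp16_elements: int) -> int:
    """Extra device memory needed to upcast fp16 tensors with .float().

    Each fp32 copy takes 4 bytes per element. The fp16 originals (2 bytes per
    element) are not freed, because the model's parameters still own them.
    """
    return 4 * num_fp16_elements


# The failed 256 MiB allocation corresponds to an fp32 tensor of roughly
# this many elements:
print(256 * MIB // 4)  # 67108864

# Hypothetical model size, for illustration only: 1.5e9 fp16 parameters.
extra = upcast_extra_bytes(1_500_000_000)
print(f"{extra / GIB:.2f} GiB")  # 5.59 GiB
```

With only 38.19 MiB free on that GPU, even a single such allocation cannot succeed.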

Also, is there a reason all of the GPUs try to write the model state (judging by the first log line being repeated once per GPU)?

P.S. I’m using Accelerate together with DeepSpeed.

Could you share your Accelerate config? We haven’t added anything DeepSpeed-specific to state saving yet, so the most likely explanation is that a sharded optimizer state is being gathered in full onto a single GPU.

The save is only done on process 0, but the log line is indeed duplicated (cc @muellerzr); we’ll fix that too!
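Until the duplicated log line is fixed upstream, one possible workaround is a standard-library `logging.Filter` attached to the logger that emits it (`accelerate.accelerator`, per the log prefix). This is a sketch: `MainProcessOnlyFilter` is a hypothetical helper, not part of Accelerate, and it relies on `accelerator.process_index` being 0 only on the main process.

```python
import logging


class MainProcessOnlyFilter(logging.Filter):
    """Drop every record unless this is the main process (index 0)."""

    def __init__(self, process_index: int):
        super().__init__()
        self.process_index = process_index

    def filter(self, record: logging.LogRecord) -> bool:
        return self.process_index == 0


# In the training script, after creating the Accelerator:
# logging.getLogger("accelerate.accelerator").addFilter(
#     MainProcessOnlyFilter(accelerator.process_index))
```

A filter attached to a logger only applies to records created by that logger, so this silences the duplicated "Saving current state" lines without touching the script's own logger.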

As a workaround for now, I’ve replaced accelerator.save_state(save_dir) with

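accelerator.wait_for_everyone()
unwrapped = accelerator.unwrap_model(model)
# only the main process writes files; otherwise every process writes the same checkpoint
unwrapped.save_pretrained(save_dir, max_shard_size='40GB', is_main_process=accelerator.is_main_process)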

But I assume this doesn’t save the rest of the state Accelerate needs to resume training (e.g. the optimizer and RNG states)?


Here’s the configuration:

❯ accelerate config
In which compute environment are you running? ([0] This machine, [1] AWS (Amazon SageMaker)): 0
Which type of machine are you using? ([0] No distributed training, [1] multi-CPU, [2] multi-GPU, [3] TPU): 2
How many different machines will you use (use more than 1 for multi-node training)? [1]: 1
Do you want to use DeepSpeed? [yes/NO]: yes
What should be your DeepSpeed's ZeRO optimization stage (0, 1, 2, 3)? [2]: 2
Where to offload optimizer states? [NONE/cpu/nvme]: cpu
How many gradient accumulation steps you're passing in your script? [1]: 16
How many processes in total will you use? [1]: 16
Do you wish to use FP16 or BF16 (mixed precision)? [NO/fp16/bf16]: fp16
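For reference, the answers above map onto a DeepSpeed config roughly like the dict below. This is a sketch built from standard DeepSpeed config keys and the values stated in the post (training_batch = 5, 16 processes, 16 accumulation steps); it is not necessarily the exact config Accelerate generates internally. It also makes the effective batch size explicit.

```python
import json

# Values taken from the post; everything else is left at DeepSpeed defaults.
per_gpu_batch = 5       # training_batch
num_processes = 16      # 16 x A100
grad_accum_steps = 16   # gradient_accumulation_steps

ds_config = {
    "train_micro_batch_size_per_gpu": per_gpu_batch,
    "gradient_accumulation_steps": grad_accum_steps,
    # DeepSpeed requires train_batch_size == micro batch * accumulation steps * world size
    "train_batch_size": per_gpu_batch * grad_accum_steps * num_processes,
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,
        # keep the (partitioned) optimizer states in CPU RAM instead of GPU memory
        "offload_optimizer": {"device": "cpu"},
    },
}

print(json.dumps(ds_config, indent=2))
```

With these settings each optimizer step sees 5 × 16 × 16 = 1280 samples.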