CUDA OOM error when using data-distributed mode on AWS p4d.24xlarge instance

I have a train.py script that works (very slowly) on a g5 instance with a single GPU and 24 GB of memory. When I deploy the same script to a multi-GPU instance with a batch size of 1, I get a CUDA OOM error.

train.py

model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", use_cache=False, quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", use_fast=True)
lora_config = LoraConfig(....)
training_args = TrainingArguments(
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    bf16=True,
   .....
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    max_seq_length=4000,
    dataset_kwargs={"add_special_tokens": False, "append_concat_token": False},
    compute_metrics=compute_metrics,
    ...
)

trainer.train()

Training is launched with the SageMaker estimator:

he = HuggingFace(
    instance_type="ml.p4d.24xlarge",
    instance_count=1,
    transformers_version="4.36.0",
    pytorch_version="2.1.0",
    py_version="py310",
    distribution={"smdistributed": {"dataparallel": {"enabled": True}}}
)

Error:

CUDA out of memory. Tried to allocate 20 MiB. GPU 0 has a total capacity of 39.9 GiB of which 1.38 MiB is free. Process 135 has 0 bytes of memory in use. Process 133 has 0 bytes memory in use. ...Of the allocated memory 4.24 GiB is allocated by PyTorch, and 113 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.

So what is happening here, and how can I fix it? The model is about 15 GB, so I’m not sure why it cannot fit into a 40 GB GPU.

Your batch activations and optimiser state also need to fit into VRAM.

Try using FP16, gradient checkpointing, and/or gradient accumulation.
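
As a rough back-of-envelope (ballpark figures, not an exact accounting of your setup), the budget and those three knobs look roughly like this (illustrative values, not your exact config):

from transformers import TrainingArguments

# Approximate VRAM budget for a 7B model:
#   fp16/bf16 weights:          7e9 params * 2 bytes ≈ 14 GB
#   gradients (same dtype):     another ~14 GB if every parameter is trainable
#   AdamW state (fp32 m and v): ~56 GB for full fine-tuning
# LoRA/QLoRA shrinks the gradient + optimizer share to the adapter weights only,
# but activations for ~4k-token sequences still add several GB per device.

training_args = TrainingArguments(
    output_dir="/tmp/run",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,   # larger effective batch without extra activation memory
    gradient_checkpointing=True,     # recompute activations in the backward pass instead of storing them
    fp16=True,                       # half precision; bf16=True is also an option on A100s
)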

I tried FP16, but it didn’t help; I’m still getting CUDA OOM errors from multiple GPUs.

Here is the updated script:

    # load the model in 4 bits
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

    model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2",
                                                 torch_dtype=torch.bfloat16,
                                                 attn_implementation="flash_attention_2",
                                                 use_cache=False,
                                                 quantization_config=quantization_config)
   
    # process 2 throws the error here
    logger.info("Model is ready")

    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", use_fast=True)
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = 'left'

    lora_config = LoraConfig(
        r=32,
        lora_alpha=16,
        lora_dropout=0.09,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        bias="none",
        task_type="CAUSAL_LM"
    )
   
    training_args = TrainingArguments(
        output_dir="/tmp/run",
        evaluation_strategy="steps",
        save_strategy="steps",        
        eval_steps=100,
        logging_steps=100,
        per_device_train_batch_size=1,
        per_device_eval_batch_size=1,
        gradient_accumulation_steps=4,
        eval_accumulation_steps=1,
        gradient_checkpointing=True,
        learning_rate=2e-4,
        logging_strategy="steps",
        num_train_epochs=args.epochs,
        optim="adamw_torch_fused",
        warmup_ratio=0.03,
        max_grad_norm=0.3,
        save_steps=200,
        lr_scheduler_type="constant",
        fp16=True,
        load_best_model_at_end=True,
        save_total_limit=1,       
        report_to="wandb",        
        run_name=args.experiment_name
    )

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset["train"],
        eval_dataset=dataset["test"],
        peft_config=lora_config,
        max_seq_length=4096,
        tokenizer=tokenizer,
        dataset_text_field="instruction",
        packing=False,
        dataset_kwargs={
            "add_special_tokens": False,
            "append_concat_token": False
        }        
    )

    logger.info("Starting training...")

    trainer.train()

And the error:

[1,mpirank:2,algo-1]<stderr>:#015Loading checkpoint shards: 100%|██████████| 3/3 [00:19<00:00,  6.40s/it]
[1,mpirank:2,algo-1]<stderr>:#015Loading checkpoint shards: 100%|██████████| 3/3 [00:19<00:00,  6.59s/it]
[1,mpirank:3,algo-1]<stdout>:2024-04-10 15:10:44,180 - __main__ - INFO - Model is ready
[1,mpirank:0,algo-1]<stdout>:2024-04-10 15:10:44,230 - __main__ - INFO - Model is ready
[1,mpirank:0,algo-1]<stderr>:/opt/conda/lib/python3.10/site-packages/accelerate/state.py:306: UserWarning: OMP_NUM_THREADS/MKL_NUM_THREADS unset, we set it at 6 to improve oob performance.
[1,mpirank:0,algo-1]<stderr>:  warnings.warn(
[1,mpirank:3,algo-1]<stderr>:/opt/conda/lib/python3.10/site-packages/accelerate/state.py:306: UserWarning: OMP_NUM_THREADS/MKL_NUM_THREADS unset, we set it at 6 to improve oob performance.
[1,mpirank:3,algo-1]<stderr>:  warnings.warn(
[1,mpirank:2,algo-1]<stderr>:Traceback (most recent call last):
[1,mpirank:2,algo-1]<stderr>:  File "/opt/conda/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[1,mpirank:2,algo-1]<stderr>:    return _run_code(code, main_globals, None,
[1,mpirank:2,algo-1]<stderr>:  File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
[1,mpirank:2,algo-1]<stderr>:    exec(code, run_globals)
[1,mpirank:2,algo-1]<stderr>:  File "/opt/conda/lib/python3.10/site-packages/mpi4py/__main__.py", line 7, in <module>
[1,mpirank:2,algo-1]<stderr>:    main()
[1,mpirank:2,algo-1]<stderr>:  File "/opt/conda/lib/python3.10/site-packages/mpi4py/run.py", line 230, in main
[1,mpirank:2,algo-1]<stderr>:    run_command_line(args)
[1,mpirank:2,algo-1]<stderr>:  File "/opt/conda/lib/python3.10/site-packages/mpi4py/run.py", line 47, in run_command_line
[1,mpirank:2,algo-1]<stderr>:    run_path(sys.argv[0], run_name='__main__')
[1,mpirank:2,algo-1]<stderr>:  File "/opt/conda/lib/python3.10/runpy.py", line 289, in run_path
[1,mpirank:2,algo-1]<stderr>:    return _run_module_code(code, init_globals, run_name,
[1,mpirank:2,algo-1]<stderr>:  File "/opt/conda/lib/python3.10/runpy.py", line 96, in _run_module_code
[1,mpirank:2,algo-1]<stderr>:    _run_code(code, mod_globals, init_globals,
[1,mpirank:2,algo-1]<stderr>:  File "/opt/conda/lib/python3.10/runpy.py", line 86, in _run_code
[1,mpirank:2,algo-1]<stderr>:    exec(code, run_globals)
[1,mpirank:2,algo-1]<stderr>:  File "train.py", line 81, in <module>
[1,mpirank:2,algo-1]<stderr>:    model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2",
[1,mpirank:2,algo-1]<stderr>:  File "/opt/conda/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 561, in from_pretrained
[1,mpirank:2,algo-1]<stderr>:    return model_class.from_pretrained(
[1,mpirank:2,algo-1]<stderr>:  File "/opt/conda/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3558, in from_pretrained
[1,mpirank:2,algo-1]<stderr>:    dispatch_model(model, **device_map_kwargs)
[1,mpirank:2,algo-1]<stderr>:  File "/opt/conda/lib/python3.10/site-packages/accelerate/big_modeling.py", line 417, in dispatch_model
[1,mpirank:2,algo-1]<stderr>:    attach_align_device_hook_on_blocks(
[1,mpirank:2,algo-1]<stderr>:  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 608, in attach_align_device_hook_on_blocks
[1,mpirank:2,algo-1]<stderr>:    add_hook_to_module(module, hook)
[1,mpirank:2,algo-1]<stderr>:  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 157, in add_hook_to_module
[1,mpirank:2,algo-1]<stderr>:    module = hook.init_hook(module)
[1,mpirank:2,algo-1]<stderr>:  File "/opt/conda/lib/python3.10/site-packages/accelerate/hooks.py", line 275, in init_hook
[1,mpirank:2,algo-1]<stderr>:    set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=self.tied_params_map)
[1,mpirank:2,algo-1]<stderr>:  File "/opt/conda/lib/python3.10/site-packages/accelerate/utils/modeling.py", line 376, in set_module_tensor_to_device
[1,mpirank:2,algo-1]<stderr>:    new_value = old_value.to(device)
[1,mpirank:2,algo-1]<stderr>:torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB. GPU 0 has a total capacty of 39.39 GiB of which 1.38 MiB is free. Process 133 has 0 bytes memory in use. Including non-PyTorch memory, this process has 0 bytes memory in use. Process 134 has 0 bytes memory in use. Process 129 has 0 bytes memory in use. Process 135 has 0 bytes memory in use. Process 130 has 0 bytes memory in use. Process 132 has 0 bytes memory in use. Process 136 has 0 bytes memory in use. Of the allocated memory 4.04 GiB is allocated by PyTorch, and 61.14 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
[1,mpirank:5,algo-1]<stdout>:2024-04-10 15:10:44,999 - __main__ - INFO - Model is ready
[1,mpirank:4,algo-1]<stdout>:2024-04-10 15:10:45,031 - __main__ - INFO - Model is ready
[1,mpirank:7,algo-1]<stdout>:2024-04-10 15:10:45,062 - __main__ - INFO - Model is ready
[1,mpirank:1,algo-1]<stdout>:2024-04-10 15:10:45,062 - __main__ - INFO - Model is ready

My question is: how come the same script worked on an instance with a single 24 GB GPU, but doesn’t work on an instance with 8 GPUs, each with 40 GB of memory? Does SageMaker data-distributed mode work with my setup?

Where is your code to set up the distributed training? I’m not sure how it works with SageMaker.

I’m using the reference below; my understanding is that all I need to do is add the distribution parameter to the SageMaker HuggingFace estimator:

Ref: Run training on Amazon SageMaker

# configuration for running training on smdistributed Data Parallel
distribution = {'smdistributed':{'dataparallel':{ 'enabled': True }}}

huggingface_estimator = HuggingFace(
    entry_point='train.py',
    source_dir='./scripts',
    instance_type='ml.p4d.24xlarge',
    instance_count=1,
    transformers_version='4.36.0',
    pytorch_version='2.1.0',
    py_version='py310',
    role=role,
    hyperparameters=hyperparameters,
    disable_output_compression=True,
    metric_definitions=metric_definitions,   
    base_job_name=base_job_name,
    distribution=distribution
)
huggingface_estimator.fit({'train': training_input_path, 'test': validation_input_path}, wait=True)

I must admit SageMaker isn’t my area, though I think there is a clue in the fact that it is GPU 0 throwing the error; perhaps the main device is experiencing additional overhead, though I’m unsure why that would be.
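
One thing that stands out in your traceback is that it goes through accelerate's dispatch_model, which looks like every rank is trying to place the quantized model on GPU 0. I haven't tried this on SageMaker myself, but a common workaround when combining bitsandbytes with data parallelism is to pin device_map to each process's own GPU when loading the model. This is just a sketch, assuming the launcher exposes the local rank through LOCAL_RANK or the OpenMPI equivalent:

import os

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Which GPU this process should own; SMDDP launches via MPI, so fall back to the
# OpenMPI variable if LOCAL_RANK isn't set (assumption: one of the two is exported).
local_rank = int(os.environ.get("LOCAL_RANK", os.environ.get("OMPI_COMM_WORLD_LOCAL_RANK", 0)))

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.2",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    use_cache=False,
    quantization_config=quantization_config,
    device_map={"": local_rank},  # each rank loads onto its own GPU instead of everyone hitting cuda:0
)

That way each rank keeps its own copy of the 4-bit weights on its own device rather than all of them piling onto cuda:0.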

I think one of the GPUs needs to handle the overhead of averaging the weight updates coming from the rest of the GPUs. But it defeats the purpose of using a bigger instance if it takes more than 40 GB to host the model (quantized to 4 bits) plus that overhead.