Llama3-8B - FSDP + QLoRA results in OOM with 4 A40s

Hardware:
CPU: Xeon® E5-2630 v2, with system RAM limited to 16GB (which is what the vast.ai instance provides)
GPU: 4x A40 → ~180GB of VRAM in total

OS: Linux
Python: 3.10
CUDA: 12.2

packages:

torch==2.3.1
transformers==4.41.2
peft==0.11.1
datasets==2.20.0
accelerate==0.31.0
evaluate==0.4.1
bitsandbytes==0.43.1
huggingface_hub==0.23.4
trl==0.9.4
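
These can be installed with pinned versions in one go (assuming a fresh virtual environment):

pip install torch==2.3.1 transformers==4.41.2 peft==0.11.1 datasets==2.20.0 accelerate==0.31.0 evaluate==0.4.1 bitsandbytes==0.43.1 huggingface_hub==0.23.4 trl==0.9.4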

Issue

Introduction

Hi!
I’m trying to fine-tune Llama3-8B on a summarization dataset of about 1500 instances. The dataset contains long documents, often over 8K tokens, and I want to use FSDP + QLoRA for the fine-tuning. After reading the guides below I was hopeful this would be possible on my setup, since I’m fine-tuning the 8B version instead of the 70B version.

I’m following these two guides as inspiration:
bitsandbytes Guide
Phil Schmid Guide

Phil Schmid’s guide mentions the following:
Expected Memory usage:
Full-finetuning with FSDP needs ~16x 80GB GPUs
FSDP + LoRA needs ~8x 80GB GPUs
FSDP + Q-Lora needs ~2x 40GB GPUs
FSDP + Q-Lora + CPU offloading needs 4x 24GB GPUs, with 22 GB/GPU and 127 GB CPU RAM with a sequence length of 3072 and a batch size of 1.
Note: To NOT use CPU offloading you need to change the value of fsdp and remove offload. This only works on > 40GB GPUs since it requires more memory.
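
If I understand that note correctly, it refers to the fsdp string that the guide passes to TrainingArguments; a minimal sketch of the two variants (values are illustrative, not copied from the guide):

from transformers import TrainingArguments

# With CPU offloading (needs a lot of host RAM):
args_offload = TrainingArguments(output_dir="out", fsdp="full_shard auto_wrap offload")

# Without CPU offloading (per the note, only feasible on >40GB GPUs):
args_no_offload = TrainingArguments(output_dir="out", fsdp="full_shard auto_wrap")

In my case I configure FSDP through accelerate instead, where the equivalent switch is fsdp_offload_params (see the config below).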

Accelerate config setup:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: false #Was true before
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_use_orig_params: false
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
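
This is what accelerate launch picks up by default (e.g. after running accelerate config); if the file is stored elsewhere, it can be passed explicitly at launch (the file name here is just an example):

accelerate launch --config_file fsdp_qlora.yaml training.py --bf16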

Code

import os

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    EarlyStoppingCallback,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer

quantization_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_storage=torch.bfloat16,
            )

model = AutoModelForCausalLM.from_pretrained(
            'meta-llama/Meta-Llama-3-8B', 
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
            attn_implementation="sdpa",
            use_cache=False
            )
        
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Meta-Llama-3-8B')
tokenizer.pad_token = tokenizer.eos_token

lora_config = LoraConfig(
            r= 8,
            lora_alpha=16,
            lora_dropout=0.1,
            target_modules = ["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
            task_type= 'CAUSAL_LM',
            bias= 'none',

        )

model = get_peft_model(model, lora_config)

training_args = TrainingArguments(
            output_dir = os.path.join('results', model_id, 'output'),
            num_train_epochs = 40,
            per_device_train_batch_size = 1,
            per_device_eval_batch_size = 1, 
            gradient_accumulation_steps = 1,  # expects an integer number of accumulation steps
            warmup_ratio = args.warmup_ratio,
            weight_decay = args.weight_decay,
            logging_dir = os.path.join('results', model_id, 'logs'),
            remove_unused_columns = False,        
            load_best_model_at_end = True,
            metric_for_best_model = "eval_loss",  # metric name used to pick the best checkpoint
            save_strategy= "epoch",
            save_total_limit= 2,
            evaluation_strategy = "epoch",
            label_names=["labels"],
            report_to = "wandb",
            logging_strategy = "epoch",
            run_name = model_id,
            eval_accumulation_steps = 1,
            hub_model_id = f"{model_id}",
            gradient_checkpointing= True,
            fp16= args.fp16,
            bf16= args.bf16,
            ddp_find_unused_parameters = True,
            gradient_checkpointing_kwargs= {'use_reentrant': False},
        )

trainer = SFTTrainer(
            model = model, 
            tokenizer = tokenizer, 
            args = training_args,
            train_dataset = dataset["train"],
            eval_dataset = dataset["validation"],
            max_seq_length = context_length_abstractive_model, #8192 
            callbacks = [EarlyStoppingCallback(early_stopping_patience = args.early_stopping_patience)],
            peft_config = lora_config,
            packing= True
            )

trainer.train()

Start training

accelerate launch training.py --bf16

Errors:

First I followed the guides exactly and set fsdp_cpu_ram_efficient_loading to true. But when I did this, the OS would sometimes kill the process with a SIGKILL(9) error:
[Screenshot 2024-06-17 at 11:09:46 showing the SIGKILL(9) output]
This makes sense, as Phil Schmid also recommends a pretty hefty amount of CPU memory: 127 GB of CPU RAM for a sequence length of 3072 and a batch size of 1.

But oddly enough, I can currently run the script with fsdp_cpu_ram_efficient_loading set to either true or false without receiving the SIGKILL(9) error. However, in both cases I get the following OOM error:

[rank1]: Traceback (most recent call last):
[rank1]:   File "/workspace/Thesis/training.py", line 703, in <module>
[rank1]:     trainer.train()
[rank1]:   File "/workspace/Thesis/venv/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 440, in train
[rank1]:     output = super().train(*args, **kwargs)
[rank1]:   File "/workspace/Thesis/venv/lib/python3.10/site-packages/transformers/trainer.py", line 1885, in train
[rank1]:     return inner_training_loop(
[rank1]:   File "/workspace/Thesis/venv/lib/python3.10/site-packages/transformers/trainer.py", line 2216, in _inner_training_loop
[rank1]:     tr_loss_step = self.training_step(model, inputs)
[rank1]:   File "/workspace/Thesis/venv/lib/python3.10/site-packages/transformers/trainer.py", line 3250, in training_step
[rank1]:     self.accelerator.backward(loss)
[rank1]:   File "/workspace/Thesis/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 2134, in backward
[rank1]:     loss.backward(**kwargs)
[rank1]:   File "/workspace/Thesis/venv/lib/python3.10/site-packages/torch/_tensor.py", line 525, in backward
[rank1]:     torch.autograd.backward(
[rank1]:   File "/workspace/Thesis/venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
[rank1]:     _engine_run_backward(
[rank1]:   File "/workspace/Thesis/venv/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
[rank1]:     return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
[rank1]: torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 32.13 GiB. GPU  has a total capacity of 44.35 GiB of which 20.85 GiB is free. Process 787350 has 23.49 GiB memory in use. Of the allocated memory 18.22 GiB is allocated by PyTorch, and 4.84 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
W0617 09:10:40.805000 140644428781376 torch/distributed/elastic/multiprocessing/api.py:851] Sending process 3244 closing signal SIGTERM

As you can see, the model runs out of memory during the backward pass. I find this pretty odd, as I should have enough GPU memory to accommodate the 8B model with FSDP + QLoRA.
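
As an aside, the allocator hint from the error message can be tried by setting PYTORCH_CUDA_ALLOC_CONF before launching; this only helps against fragmentation, not against genuinely needing more memory than is available:

PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True accelerate launch training.py --bf16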

Possible limitations

CPU has too little RAM. Offloading isn’t possible because we only have 16GB of CPU RAM. But following Phil Schmid’s guide, not offloading to the CPU should still work, since we use 4 A40s (> 40GB each). This makes the OOM even more odd, considering I’m using the 8B version instead of the 70B version used in both guides.

Not using Flash Attention 2 could also be an issue, but as seen in Phil Schmid’s guide, SDPA can also be used.

The sequence length is too long, causing the OOM. I tried setting max_seq_length to 512, but this didn’t have any impact on the OOM issue.
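
To narrow down which of these actually eats the memory, a small helper like the one below (my own addition, not from either guide) can be called on each rank, e.g. right before trainer.train() and from a logging callback:

import torch

def log_gpu_memory(tag: str) -> None:
    # Report allocated/reserved/peak memory for the GPU used by this rank.
    device = torch.cuda.current_device()
    allocated = torch.cuda.memory_allocated(device) / 2**30
    reserved = torch.cuda.memory_reserved(device) / 2**30
    peak = torch.cuda.max_memory_allocated(device) / 2**30
    print(f"[{tag}] allocated={allocated:.2f} GiB, reserved={reserved:.2f} GiB, peak={peak:.2f} GiB")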

Caveat

When I first dove into the rabbit hole of FSDP and QLoRA, I started out simple and just used the following code:

quantization_config = BitsAndBytesConfig(
            load_in_8bit=True,
            )

model = AutoModelForCausalLM.from_pretrained(
            'meta-llama/Meta-Llama-3-8B', 
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
            device_map = 'auto',
            use_cache=False if args.gradient_checkpointing else True,
            )

I launched the code with:

python3 training.py

This didn’t result in an OOM error and I was able to train for 100 steps. However, it took quite long and would become too expensive for me, as the full training would probably last over 200 hours… I could see that GPU memory was used pretty well, with all GPUs filled up to 40GB or so. Because this took so long, I wanted to use QLoRA. But I couldn’t use QLoRA and device_map='auto' together, which is why I resorted to FSDP in combination with QLoRA.
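
For reference, with device_map='auto' the model is split across the GPUs rather than replicated, and the layer-to-device assignment can be inspected after loading (a quick sanity check on the snippet above):

# transformers records how accelerate distributed the layers across devices:
print(model.hf_device_map)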

I don’t really know why using QLoRA in combination with FSDP would then result in OOM again, which makes me even more confused.

If you have any ideas, please let me know as I’m getting a bit frustrated after being stuck on this for a few days!

I fixed the issue! There were some things I did wrong:

  • fsdp_cpu_ram_efficient_loading needs to be set to true → I mistook this parameter for fsdp_offload_params, which needs to be set to false! When fsdp_cpu_ram_efficient_loading is set to false, the host runs out of RAM and the OS kills the process with SIGKILL(9).
  • gradient_checkpointing_kwargs = {'use_reentrant': False} needs to be changed to {'use_reentrant': True}!! When use_reentrant is set to False, it takes up a lot of GPU VRAM.
  • ddp_find_unused_parameters = True needs to be un-assigned or set to False; it is automatically set to False when gradient checkpointing is used.
  • Using get_peft_model together with passing a PEFT configuration to the SFTTrainer causes an error:
Traceback (most recent call last):
  File "/workspace/Thesis/training.py", line 705, in <module>
    trainer.train()
  File "/workspace/Thesis/venv/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 440, in train
    output = super().train(*args, **kwargs)
  File "/workspace/Thesis/venv/lib/python3.10/site-packages/transformers/trainer.py", line 1885, in train
    return inner_training_loop(
  File "/workspace/Thesis/venv/lib/python3.10/site-packages/transformers/trainer.py", line 2216, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/workspace/Thesis/venv/lib/python3.10/site-packages/transformers/trainer.py", line 3250, in training_step
    self.accelerator.backward(loss)
  File "/workspace/Thesis/venv/lib/python3.10/site-packages/accelerate/accelerator.py", line 2134, in backward
    loss.backward(**kwargs)
  File "/workspace/Thesis/venv/lib/python3.10/site-packages/torch/_tensor.py", line 525, in backward
    torch.autograd.backward(
  File "/workspace/Thesis/venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
  File "/workspace/Thesis/venv/lib/python3.10/site-packages/torch/autograd/graph.py", line 744, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

So be sure not to wrap your model with get_peft_model if you also pass a peft_config to the SFTTrainer!
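
Putting these fixes together, the relevant pieces of the training script now look roughly like this (a trimmed sketch with the unchanged arguments left out; the accelerate config additionally gets fsdp_cpu_ram_efficient_loading: true and fsdp_offload_params: false):

# Note: no get_peft_model() call; the LoRA wrapping is left to SFTTrainer via peft_config.
training_args = TrainingArguments(
    output_dir = os.path.join('results', model_id, 'output'),
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 4,
    gradient_checkpointing = True,
    gradient_checkpointing_kwargs = {'use_reentrant': True},  # use_reentrant=False used far more VRAM
    bf16 = True,
    # ddp_find_unused_parameters is intentionally left unset
)

trainer = SFTTrainer(
    model = model,                 # the plain quantized model, NOT wrapped with get_peft_model
    tokenizer = tokenizer,
    args = training_args,
    train_dataset = dataset["train"],
    eval_dataset = dataset["validation"],
    peft_config = lora_config,     # SFTTrainer applies the LoRA adapters itself
    max_seq_length = context_length_abstractive_model,
    packing = True,
)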

Looking back, it makes sense that my script worked with:

python3 training.py

When the script was run like this, the model was simply spread across the GPUs via device_map='auto' (naive model parallelism, with only one GPU active at a time) instead of being trained with FSDP. That also explains why training would have taken around 200 hours.

With my current setup I’m able to reduce training to around 60 hours. With a per_device_train_batch_size of 1 and gradient_accumulation_steps of 4, the memory of my GPUs is almost maxed out. I think this is due to the long sequence length being used.

If anyone has any recommendations on how to speed up the remainder of the training process, feel free to let me know!