Issue with LoRA Adapter Loading on Multiple GPUs during Fine-Tuning with Accelerate and SFTTrainer

Hello Hugging Face Community,

I’m working on fine-tuning a pre-trained LLaMA 3.1 model using LoRA adapters with the goal of performing additive tuning—continuing to fine-tune an existing LoRA adapter or adding a new one. I’m using the transformers, peft, trl, and accelerate libraries for this task.

When I load the pre-trained LoRA adapter onto my base model and distribute the model across multiple GPUs using Accelerate, GPU memory usage does not look normal. My guess is that all of the LoRA adapter weights are loaded onto GPU 0 only, while the base model is correctly distributed across all GPUs. This leads to an imbalance in GPU memory usage, with GPU 0 consuming significantly more memory than the others. The same problem does not occur if I load a base Llama 3.1 model and do a first round of LoRA fine-tuning, i.e. build a PEFT model from scratch from a LoraConfig (roughly like the sketch below) instead of loading an adapter and setting it as trainable.
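
For reference, the from-scratch setup that behaves normally looks roughly like this; it is only a sketch, and the rank, alpha, and target modules are placeholders rather than my exact settings:

from peft import LoraConfig, get_peft_model

# Sketch of the first-round (from-scratch) path: build a fresh PEFT model
# from a LoraConfig instead of loading an existing adapter.
# Hyperparameters and target modules below are illustrative placeholders.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)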

Has anyone encountered a similar problem? Any help or comments would be highly appreciated.

Thanks in advance.

Here is my nvidia-smi output:

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A   1939477      C   ...envs/torch/bin/python3.10    10732MiB |
|    0   N/A  N/A   1939478      C   ...envs/torch/bin/python3.10     1116MiB |
|    0   N/A  N/A   1939479      C   ...envs/torch/bin/python3.10     1116MiB |
|    0   N/A  N/A   1939480      C   ...envs/torch/bin/python3.10     1116MiB |
|    0   N/A  N/A   1939481      C   ...envs/torch/bin/python3.10     1116MiB |
|    0   N/A  N/A   1939482      C   ...envs/torch/bin/python3.10     1116MiB |
|    0   N/A  N/A   1939483      C   ...envs/torch/bin/python3.10     1116MiB |
|    0   N/A  N/A   1939484      C   ...envs/torch/bin/python3.10     1116MiB |
|    1   N/A  N/A   1939478      C   ...envs/torch/bin/python3.10    10652MiB |
|    2   N/A  N/A   1939479      C   ...envs/torch/bin/python3.10    10652MiB |
|    3   N/A  N/A   1939480      C   ...envs/torch/bin/python3.10    10652MiB |
|    4   N/A  N/A   1939481      C   ...envs/torch/bin/python3.10    10652MiB |
|    5   N/A  N/A   1939482      C   ...envs/torch/bin/python3.10    10652MiB |
|    6   N/A  N/A   1939483      C   ...envs/torch/bin/python3.10    10652MiB |
|    7   N/A  N/A   1939484      C   ...envs/torch/bin/python3.10    10652MiB |
+-----------------------------------------------------------------------------+

Here is the code where I load the model and the adapter:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftConfig, PeftModel, prepare_model_for_kbit_training

# Load the base model in 4-bit, then attach the existing LoRA adapter as trainable
config = PeftConfig.from_pretrained(args.parent_model_path)
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    torch_dtype=compute_dtype,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=compute_dtype,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
    ),
    trust_remote_code=True,
)
model = prepare_model_for_kbit_training(model)
model = PeftModel.from_pretrained(
    model,
    args.parent_model_path,
    is_trainable=True,
)


# ... some further setup


training_args = SFTConfig(
    output_dir=os.path.join("results", args.train_name, "checkpoint"),
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_device_train_batch_size=args.per_device_batch_size,
    per_device_eval_batch_size=args.per_device_batch_size,
    gradient_accumulation_steps=4,
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    fp16=not torch.cuda.is_bf16_supported(),
    bf16=torch.cuda.is_bf16_supported(),
    learning_rate=2e-4,
    logging_steps=10,
    evaluation_strategy="steps",
    eval_steps=500,
    save_strategy="steps",
    save_steps=1000,
    max_seq_length=args.max_seq_length,
    logging_first_step=True,
    push_to_hub=False,
    ddp_find_unused_parameters=False,
    load_best_model_at_end=False,
    report_to="wandb",
    run_name=args.train_name,
    warmup_steps=100,  # note: warmup_steps takes precedence over warmup_ratio when both are set
    warmup_ratio=0.1,
)


trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    max_seq_length=args.max_seq_length,
    tokenizer=tokenizer,
    data_collator=collator,
    formatting_func=formatting_func,
)


trainer = accelerator.prepare(trainer)
trainer.train()

Normally this should not happen; the LoRA adapters should be loaded onto the same device as the base layer they're attached to.

After loading the LoRA adapter, could you please loop through the weights and, if the name contains “lora_A” or “lora_B”, print the device of the weight? That way we can quickly check whether all LoRA weights are indeed loaded onto GPU 0 or not.
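
Something along these lines should do it (just a quick sketch):

# Sketch: print the device of every LoRA parameter after loading the adapter
for name, param in model.named_parameters():
    if "lora_A" in name or "lora_B" in name:
        print(f"LoRA weight {name} is on device: {param.device}")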

Hi @BenjaminB,
Thanks so much for your prompt response and suggestions. I really appreciate it.

I added code to print the weights whose names contain lora_A and lora_B, to inspect which device each LoRA weight actually sits on:

if args.additive_tuning:
    # keep tuning after the previous LoRA adapter tuning
    config = PeftConfig.from_pretrained(args.parent_model_path)
    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        torch_dtype=compute_dtype,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
        ),
        trust_remote_code=True,
        device_map=None
    )
    model = prepare_model_for_kbit_training(model)
    model = PeftModel.from_pretrained(model, 
                                      args.parent_model_path, 
                                      is_trainable=True)
    for name, param in model.named_parameters():
        if 'lora_A' in name or 'lora_B' in name:
            print(f"LoRA weight {name} is on device: {param.device}")

And you are correct: all 8 CUDA devices hold a copy of the LoRA adapter weights rather than everything sitting on GPU 0. Here is part of the output:


LoRA weight base_model.model.model.layers.1.self_attn.k_proj.lora_A.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.1.self_attn.k_proj.lora_B.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.1.self_attn.v_proj.lora_A.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.1.self_attn.v_proj.lora_B.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.1.self_attn.o_proj.lora_A.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.1.self_attn.o_proj.lora_B.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.1.mlp.gate_proj.lora_A.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.1.mlp.gate_proj.lora_B.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.1.mlp.up_proj.lora_A.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.1.mlp.up_proj.lora_B.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.1.mlp.down_proj.lora_A.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.1.mlp.down_proj.lora_B.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.2.self_attn.q_proj.lora_A.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.2.self_attn.q_proj.lora_B.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.2.self_attn.k_proj.lora_A.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.2.self_attn.k_proj.lora_B.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.2.self_attn.v_proj.lora_A.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.2.self_attn.v_proj.lora_B.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.2.self_attn.o_proj.lora_A.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.2.self_attn.o_proj.lora_B.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.2.mlp.gate_proj.lora_A.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.2.mlp.gate_proj.lora_B.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.2.mlp.up_proj.lora_A.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.2.mlp.up_proj.lora_B.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.2.mlp.down_proj.lora_A.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.2.mlp.down_proj.lora_B.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.3.self_attn.q_proj.lora_A.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.3.self_attn.q_proj.lora_B.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.3.self_attn.k_proj.lora_A.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.3.self_attn.k_proj.lora_B.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.3.self_attn.v_proj.lora_A.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.3.self_attn.v_proj.lora_B.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.3.self_attn.o_proj.lora_A.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.3.self_attn.o_proj.lora_B.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.3.mlp.gate_proj.lora_A.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.3.mlp.gate_proj.lora_B.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.3.mlp.up_proj.lora_A.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.3.mlp.up_proj.lora_B.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.3.mlp.down_proj.lora_A.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.3.mlp.down_proj.lora_B.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.4.self_attn.q_proj.lora_A.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.4.self_attn.q_proj.lora_B.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.4.self_attn.k_proj.lora_A.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.4.self_attn.k_proj.lora_B.default.weight is on device: cuda:4
LoRA weight base_model.model.model.layers.4.self_attn.v_proj.lora_A.default.weight is on device: cuda:4
...
LoRA weight base_model.model.model.layers.24.self_attn.q_proj.lora_A.default.weight is on device: cuda:0
LoRA weight base_model.model.model.layers.24.self_attn.q_proj.lora_B.default.weight is on device: cuda:0
LoRA weight base_model.model.model.layers.24.self_attn.k_proj.lora_A.default.weight is on device: cuda:0
LoRA weight base_model.model.model.layers.24.self_attn.k_proj.lora_B.default.weight is on device: cuda:0
LoRA weight base_model.model.model.layers.24.self_attn.v_proj.lora_A.default.weight is on device: cuda:0

... I'm not showing the full output because of the per-post word limit, but all of the CUDA devices were indeed included.

However, I still see more load on GPU 0 than on the rest of the GPUs. Could you suggest any other possible problems I may have introduced that could cause this?

It is very hard to tell based on the given information. You mentioned above that more GPU memory is used on GPU 0; now you mention process load. So is it both? At the same time? How much extra memory are we talking about, and how long is the processing?

What would also help with debugging is figuring out at which step this extra load occurs. You showed some code, but it's not clear which step exactly is responsible.
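
For example, you could bracket each setup step with a small memory probe, something like the sketch below (the helper name and step labels are just illustrative):

import torch

def log_gpu_memory(tag):
    # Illustrative helper: print the memory PyTorch has allocated on every
    # visible GPU, as seen from the current process, so you can tell which
    # step starts allocating on a GPU other than the one this rank should use.
    for i in range(torch.cuda.device_count()):
        mib = torch.cuda.memory_allocated(i) / 1024**2
        print(f"[{tag}] cuda:{i}: {mib:.0f} MiB allocated by this process")

log_gpu_memory("after base model load")
# ... PeftModel.from_pretrained(...) ...
log_gpu_memory("after adapter load")
# ... SFTTrainer(...) ...
log_gpu_memory("after trainer init")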