Errors when using gradient accumulation with FSDP + PEFT LoRA + SFTTrainer

Hello!

I’ve been running into a strange error with my distributed training setup. I’m trying to fine-tune Llama 3 8B Instruct on 2 A6000s (more eventually, but 2 for now), using FSDP, PEFT LoRA, and the SFTTrainer from the trl library. When using mixed precision (bf16), I get the following error at line 380 of torch/optim/adamw.py:

RuntimeError: expected dtype float for *end* but got dtype c10::BFloat16

The relevant line:

# Decay the first and second moment running average coefficient
exp_avg.lerp_(grad, 1 - beta1)

The tensor data types:

exp_avg.dtype == torch.float32 
grad.dtype == torch.bfloat16 
beta1 == float
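
That dtype combination reproduces the error on its own, outside the trainer (a minimal sketch; the tensor size 8 is just a placeholder):

import torch

exp_avg = torch.zeros(8, dtype=torch.float32)  # optimizer state kept in fp32
grad = torch.zeros(8, dtype=torch.bfloat16)    # gradient still in bf16
exp_avg.lerp_(grad, 1 - 0.9)  # RuntimeError: expected dtype float for `end` but got dtype c10::BFloat16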

When I remove mixed precision, I get a different error on the same line:

The size of tensor a (16384) must match the size of tensor b (4096) at non-singleton dimension 1

The tensor shapes:

exp_avg.shape == torch.Size([16384])
grad.shape == torch.Size([8, 4096])

When I set use_orig_params=False, I get the same size-mismatch error, but with grad.shape == torch.Size([32768]).
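
For what it’s worth, the numbers line up with FSDP sharding a single flattened 8 x 4096 LoRA matrix across the 2 GPUs (my own back-of-the-envelope arithmetic, not taken from the traceback):

r, hidden, world_size = 8, 4096, 2
full_numel = r * hidden               # 32768, matches grad.shape with use_orig_params=False
per_rank = full_numel // world_size   # 16384, matches exp_avg.shape above
print(full_numel, per_rank)           # 32768 16384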

When I set gradient_accumulation_steps=1, these problems all go away. Why is this the case? What can I change so my training code works with gradient accumulation?
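
For context, my understanding is that with gradient_accumulation_steps > 1 the Trainer goes through Accelerate’s accumulation path, which skips gradient synchronization on the intermediate micro-batches. A toy standalone version of that loop looks like this (my simplified sketch of the documented Accelerate pattern, not the actual Trainer code):

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=8)
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4)
loader = DataLoader(TensorDataset(torch.randn(64, 4), torch.randn(64, 1)), batch_size=1)
model, optimizer, loader = accelerator.prepare(model, optimizer, loader)

for x, y in loader:
    # On non-sync steps Accelerate wraps the backward in the model's no_sync()
    # context, so gradients are only reduced every 8th micro-batch.
    with accelerator.accumulate(model):
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()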

Here are all the parameters I’m passing to accelerate launch (a sketch of how I expect the --lora_* flags to be used follows the list):

--use_fsdp 
--mixed-precision=bf16
--num-machines=1 
--rdzv-backend=static 
--same-network 
--main-training-function=main 
--machine-rank=0 
--num-processes=2 
--gpu-ids=0,1
--fsdp-auto-wrap-policy=TRANSFORMER_BASED_WRAP 
--fsdp-backward-prefetch=BACKWARD_PRE 
--fsdp-sharding-strategy=FULL_SHARD 
--fsdp-state-dict-type=FULL_STATE_DICT 
--fsdp-activation-checkpointing=False 
--fsdp-sync-module-states=True 
--fsdp-use-orig-params=True 
--fsdp-cpu-ram-efficient-loading=True 
--fsdp-forward-prefetch=False 
--fsdp-offload-params=True
<train.py>
--per_device_train_batch_size=1 
--num_train_epochs=1
--gradient_accumulation_steps=8 
--gradient_checkpointing=True 
--learning_rate=0.0002 
--report_to=none 
--optim=adamw_torch 
--max_seq_length=4096 
--lr_scheduler_type=constant 
--logging_steps=1
--lora_r=8 
--lora_alpha=32 
--lora_dropout=0.1 
--model_name=meta-llama/Meta-Llama-3-8B-Instruct 
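
For reference, I’d expect the --lora_* flags above to end up as a peft LoraConfig roughly like this (my assumption about what setup_dataset_and_model builds; target_modules isn’t covered by the flags, so it’s omitted here):

from peft import LoraConfig

peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    task_type="CAUSAL_LM",
)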

Relevant part of the training code:

parser = HfArgumentParser((TrainingArguments, ScriptArguments))  # type: ignore

sft_config, args = parser.parse_args_into_dataclasses()
sft_config.remove_unused_columns = False  # Necessary for the collator to have access to traj metadata
sft_config.gradient_checkpointing_kwargs = args.g_c_kwargs
sft_config.dataset_text_field = "text"

tokenizer = AutoTokenizer.from_pretrained(args.model_name)

dataset, model, peft_config = setup_dataset_and_model(args, format_dataset, tokenizer)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    args=sft_config,
    peft_config=peft_config,
    data_collator=collator,
    max_seq_length=args.max_seq_length,
)

# Remove columns that aren't needed; otherwise training errors when it tries to cast these strings to tensors
trainer.train_dataset = trainer.train_dataset.remove_columns(["text", "messages"])  # type: ignore

# handle PEFT+FSDP case
print_trainable_parameters(trainer.model)
if getattr(trainer.accelerator.state, "fsdp_plugin", None):
    from peft.utils.other import fsdp_auto_wrap_policy

    fsdp_plugin = trainer.accelerator.state.fsdp_plugin
    fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(trainer.model)

# Train the model
trainer.train()  # type: ignore

I have the same issue. Were you able to solve it?

I see this reported here as well: https://github.com/huggingface/transformers/issues/25695 (tensor size mismatch with larger gradient_accumulation_steps and fewer training data)