Hello!
I’ve been running into a weird error with my distributed training setup. I’m trying to fine-tune Llama 3 8B Instruct on 2 A6000s (eventually more, but 2 for now) using FSDP, PEFT LoRA, and the `SFTTrainer` from the `trl` library. With mixed precision (bf16) enabled, I get the following error on line 380 of `torch/optim/adamw.py`:
```
RuntimeError: expected dtype float for `end` but got dtype c10::BFloat16
```
The relevant line:
```python
# Decay the first and second moment running average coefficient
exp_avg.lerp_(grad, 1 - beta1)
```
The tensor data types:
```python
exp_avg.dtype == torch.float32
grad.dtype == torch.bfloat16
type(beta1) == float
```
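As far as I can tell, `lerp_` requires `end` to have the same dtype as `self`, so the failing call can be reproduced in isolation with those dtypes (a toy snippet of mine, not from the traceback):

```python
import torch

exp_avg = torch.zeros(16384, dtype=torch.float32)  # fp32 optimizer state
grad = torch.zeros(16384, dtype=torch.bfloat16)    # bf16 gradient
beta1 = 0.9

exp_avg.lerp_(grad, 1 - beta1)  # should raise the same "expected dtype float for `end`" RuntimeError
```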
When I remove mixed precision, I get a different error on the same line:
```
The size of tensor a (16384) must match the size of tensor b (4096) at non-singleton dimension 1
```
The tensor shapes:
```python
exp_avg.shape == torch.Size([16384])
grad.shape == torch.Size([8, 4096])
```
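That mismatch is also easy to trigger outside the optimizer with the same shapes (again just a toy snippet):

```python
import torch

exp_avg = torch.zeros(16384)  # flat (sharded-looking) optimizer state
grad = torch.zeros(8, 4096)   # un-flattened, LoRA-sized gradient

exp_avg.lerp_(grad, 0.1)  # should fail with the same size-mismatch message
```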
When I set `use_orig_params=False`, I get the same error but with `grad.shape == torch.Size([32768])`.
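For what it’s worth, the numbers look consistent with a flat, sharded parameter being paired against an un-sharded gradient (just my reading of the shapes, I may be off):

```python
# With lora_r=8 and Llama 3's hidden size of 4096, one LoRA matrix has
# 8 * 4096 = 32768 elements -- the grad shape I see with use_orig_params=False.
full_numel = 8 * 4096       # 32768
per_rank = full_numel // 2  # 16384 -> matches exp_avg.shape with FULL_SHARD over 2 GPUs
print(full_numel, per_rank)
```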
When I set `gradient_accumulation_steps=1`, these problems all go away. Why is this the case? What can I change so my training code works with gradient accumulation?
Here are all the parameters I’m passing to `accelerate launch`:
```
--use_fsdp
--mixed-precision=bf16
--num-machines=1
--rdzv-backend=static
--same-network
--main-training-function=main
--machine-rank=0
--num-processes=2
--gpu-ids=0,1
--fsdp-auto-wrap-policy=TRANSFORMER_BASED_WRAP
--fsdp-backward-prefetch=BACKWARD_PRE
--fsdp-sharding-strategy=FULL_SHARD
--fsdp-state-dict-type=FULL_STATE_DICT
--fsdp-activation-checkpointing=False
--fsdp-sync-module-states=True
--fsdp-use-orig-params=True
--fsdp-cpu-ram-efficient-loading=True
--fsdp-forward-prefetch=False
--fsdp-offload-params=True
<train.py>
--per_device_train_batch_size=1
--num_train_epochs=1
--gradient_accumulation_steps=8
--gradient_checkpointing=True
--learning_rate=0.0002
--report_to=none
--optim=adamw_torch
--max_seq_length=4096
--lr_scheduler_type=constant
--logging_steps=1
--lora_r=8
--lora_alpha=32
--lora_dropout=0.1
--model_name=meta-llama/Meta-Llama-3-8B-Instruct
```
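Just to spell out what that means for the optimizer: each optimizer step accumulates over 1 × 2 × 8 = 16 sequences (trivial arithmetic, values straight from the flags above):

```python
# Effective batch size per optimizer step with the settings above
per_device = 1   # --per_device_train_batch_size
num_gpus = 2     # --num-processes
grad_accum = 8   # --gradient_accumulation_steps
print(per_device * num_gpus * grad_accum)  # 16
```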
Relevant part of the training code:
```python
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments
from trl import SFTTrainer

# (ScriptArguments, setup_dataset_and_model, format_dataset, collator and
#  print_trainable_parameters are defined elsewhere in the script)

parser = HfArgumentParser((TrainingArguments, ScriptArguments))  # type: ignore
sft_config, args = parser.parse_args_into_dataclasses()
sft_config.remove_unused_columns = False  # Necessary for the collator to have access to traj metadata
sft_config.gradient_checkpointing_kwargs = args.g_c_kwargs
sft_config.dataset_text_field = "text"

tokenizer = AutoTokenizer.from_pretrained(args.model_name)
dataset, model, peft_config = setup_dataset_and_model(args, format_dataset, tokenizer)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    args=sft_config,
    peft_config=peft_config,
    data_collator=collator,
    max_seq_length=args.max_seq_length,
)

# Remove the columns that are not needed, or training will error when it tries to cast these strings to tensors
trainer.train_dataset = trainer.train_dataset.remove_columns(["text", "messages"])  # type: ignore

print_trainable_parameters(trainer.model)

# Handle the PEFT + FSDP case
if getattr(trainer.accelerator.state, "fsdp_plugin", None):
    from peft.utils.other import fsdp_auto_wrap_policy

    fsdp_plugin = trainer.accelerator.state.fsdp_plugin
    fsdp_plugin.auto_wrap_policy = fsdp_auto_wrap_policy(trainer.model)

# Train the model
trainer.train()  # type: ignore
```