I am trying to fine-tune the Llama 3 model using SFT (with PEFT LoRA). While I achieved the desired performance, fine-tuning was very slow. I then set `bf16=True` in `SFTConfig`, and training became roughly 2x faster, but performance dropped to about 40% of the original.

My understanding is that the `bf16` option controls whether training runs in float32 or bfloat16. Similarly, I set the `tf32` option in `SFTConfig` to `True`; training sped up comparably, but performance still did not return to the level I had with neither option set. I have since rolled the code back, but I would like to understand why this happened.

I would appreciate it if you could point out any mistakes I made or offer advice.
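For context, my mental model is that bfloat16 keeps float32's exponent range but only about 8 bits of mantissa, so each value is rounded far more coarsely than in float32. A minimal check, separate from my training code, that illustrates the rounding:

import torch

x = torch.tensor(0.1234567, dtype=torch.float32)
print(x.item())                     # 0.1234567...
print(x.to(torch.bfloat16).item())  # ~0.12353515625: only ~2-3 decimal digits survive
print(x.to(torch.float16).item())   # ~0.12347412109375: more mantissa bits, but much smaller exponent range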
A portion of my code is below:
import torch
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

# Resolve the dtype string from args (e.g. "bfloat16") to the torch dtype.
compute_dtype = getattr(torch, bnb_args.bnb_4bit_compute_dtype)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=bnb_args.load_in_4bit,
    load_in_8bit=bnb_args.load_in_8bit,
    bnb_4it_quant_type=bnb_args.bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,  # usually torch.bfloat16
)
sft_config = SFTConfig(
    dataset_text_field=sft_args.dataset_text_field,
    max_seq_length=sft_args.max_seq_length,
    output_dir=sft_args.sft_output_dir,
    # bf16=(compute_dtype is torch.bfloat16),  # performance increased when removing this line
    # tf32=(compute_dtype is torch.float32),
)
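# Note (my understanding, not verified against the Trainer source): bf16=True
# makes the forward/backward pass run under torch.autocast with bfloat16
# instead of plain float32, which is where the ~2x speedup came from.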
dataset = load_dataset(
    "text",
    data_files=data_path,
    split=f"train[:{data_ratio}%]",  # alternatively: split="train"
)
base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=compute_dtype,
    quantization_config=bnb_config if (bnb_config.load_in_4bit or bnb_config.load_in_8bit) else None,
    device_map={"": Accelerator().local_process_index},
    low_cpu_mem_usage=True,
)
base_model.config.use_cache = False  # the KV cache is not needed during training
base_model = prepare_model_for_kbit_training(base_model)

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # Llama 3 has no pad token by default
tokenizer.padding_side = "right"
peft_config = LoraConfig(
    r=lora_args.r,
    lora_alpha=lora_args.lora_alpha,
    # all linear projections; alternatively: find_all_linear_names(base_model)
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_dropout=lora_args.lora_dropout,
    bias=lora_args.bias,
    task_type=lora_args.task_type,
)
base_model = get_peft_model(base_model, peft_config)

trainer = SFTTrainer(
    model=base_model,
    train_dataset=dataset,
    tokenizer=tokenizer,
    formatting_func=formatting_prompts_func,
    packing=False,
    args=sft_config,  # dataset_text_field and max_seq_length come from sft_config above
)
trainer.train(resume_from_checkpoint=args.checkpoint_path)
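For completeness, my understanding is that `tf32=True` in `SFTConfig` (via the underlying `TrainingArguments`) does not change the training dtype at all; it only flips PyTorch's TF32 backend flags, roughly equivalent to:

import torch

# TF32 keeps float32 storage but lets Ampere+ tensor cores round matmul
# inputs down to 10 mantissa bits, trading some precision for speed.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

so unlike `bf16`, it should not change the dtype of the weights or activations themselves.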