The role of the bf16 argument in SFTConfig

I am trying to fine-tune a Llama 3 model with SFT (and PEFT LoRA). I achieved the desired performance, but fine-tuning was very slow. I then set bf16=True in SFTConfig; fine-tuning became roughly 2x faster, but the performance dropped to about 40% of the original.

My suspicion is that the bf16 option determines whether training runs in float32 or bfloat16. Similarly, when I set the tf32 option in SFTConfig to True, training sped up in much the same way, but the performance still did not return to what it was with neither option set. I have rolled the code back for now, but I would like to understand why this happened.
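For context, this is roughly how I assumed the compute dtype and the bf16 flag are supposed to line up (this is only my understanding, so it may well be wrong):

import torch

# My assumption: bf16=True only makes sense when the GPU supports bfloat16 and when
# the quantization compute dtype is bfloat16 as well, so both could be derived from
# the same check.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
compute_dtype = torch.bfloat16 if use_bf16 else torch.float32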

I would appreciate it if you could point out any mistakes I have made or offer any advice.

A portion of my code is below:


import torch
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

compute_dtype = getattr(torch, bnb_args.bnb_4bit_compute_dtype)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=bnb_args.load_in_4bit,
    load_in_8bit=bnb_args.load_in_8bit,
    bnb_4bit_quant_type=bnb_args.bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,  # usually torch.bfloat16
)

sft_config = SFTConfig(
    dataset_text_field=sft_args.dataset_text_field,
    max_seq_length=sft_args.max_seq_length,
    output_dir=sft_args.sft_output_dir,
    # bf16=True if compute_dtype is torch.bfloat16 else False,  # performance increased when removing this line
    # tf32=True if compute_dtype is torch.float32 else False,
)
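
# For reference (my understanding, possibly wrong): the tf32 option is supposed to
# just flip these global PyTorch switches, which only have an effect on Ampere or
# newer GPUs:
# torch.backends.cuda.matmul.allow_tf32 = True
# torch.backends.cudnn.allow_tf32 = True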

dataset = load_dataset(
    "text",
    data_files=data_path,
    split=f"train[:{data_ratio}%]",
    # split="train",
)

base_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=compute_dtype,
    quantization_config=bnb_config if (bnb_config.load_in_4bit or bnb_config.load_in_8bit) else None,
    device_map={"": Accelerator().local_process_index},
    low_cpu_mem_usage=True,
)
base_model.config.use_cache = False
base_model = prepare_model_for_kbit_training(base_model)
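
# Sanity check (not in my original run): prepare_model_for_kbit_training is supposed
# to upcast the layer norms and the lm_head to float32, so the model ends up with
# mixed dtypes; this just prints what is actually there.
from collections import Counter
print(Counter(p.dtype for p in base_model.parameters()))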

tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

peft_config = LoraConfig(
    r=lora_args.r,
    lora_alpha=lora_args.lora_alpha,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],  # find_all_linear_names(base_model)
    lora_dropout=lora_args.lora_dropout,
    bias=lora_args.bias,
    task_type=lora_args.task_type,
)
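
# For reference, the find_all_linear_names helper mentioned in the comment above is
# roughly this (adapted from common QLoRA examples; my exact version may differ):
import bitsandbytes as bnb

def find_all_linear_names(model):
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, bnb.nn.Linear4bit):
            lora_module_names.add(name.split(".")[-1])
    lora_module_names.discard("lm_head")  # usually excluded from the LoRA targets
    return list(lora_module_names)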

base_model = get_peft_model(base_model, peft_config)
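# Quick check (not in my original run): confirm how many parameters LoRA actually trains.
base_model.print_trainable_parameters()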

trainer = SFTTrainer(
    model=base_model,
    train_dataset=dataset,
    dataset_text_field="text",
    tokenizer=tokenizer,
    max_seq_length=sft_args.max_seq_length,
    formatting_func=formatting_prompts_func,
    packing=False,
    args=sft_config,
)

trainer.train(resume_from_checkpoint=args.checkpoint_path)