Reward becomes nan when switching from full precision to fp16 for gemma3-12b-it

I am training gemma3-12b-it on a standard preference dataset with DPO. When I run accelerate launch train.py in full precision, the training curve looks reasonable. However, if I switch from full precision to fp16, the logging suddenly shows loss=0, grad_norm=0, reward=nan, and so on. Are multimodal models restricted to full-precision training?

from datasets import load_dataset
from trl import DPOConfig, DPOTrainer
from peft import LoraConfig, TaskType
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gemma-3-12b-it"
model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager")
tokenizer = AutoTokenizer.from_pretrained(model_name)
train_dataset = load_dataset("json", data_files="training_data.json", split="train")
tokenizer.pad_token = tokenizer.eos_token

def process_training_data(example):
    # Rename the "input" column to "prompt" and keep only the first rejected response
    example["prompt"] = example.pop("input")
    example["rejected"] = example["rejected"][0]
    return example
train_dataset = train_dataset.map(process_training_data)

training_args = DPOConfig(
    dataloader_pin_memory=False,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    logging_steps=10,
    # fp16=True  # enabling this produces loss=0, grad_norm=0, reward=nan
)
training_args.optimize_cuda_cache = True

peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
        "lm_head",
    ],
)

trainer = DPOTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=train_dataset,
    peft_config=peft_config,
)
trainer.train()

Perhaps a mixed-precision training issue?
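If so, one workaround worth trying (a sketch, assuming your GPU supports bfloat16) is bf16 mixed precision instead of fp16, since bf16 keeps the fp32 exponent range and Gemma checkpoints are, as far as I know, stored in bfloat16:

# Hypothetical fix sketch: same DPOConfig as above, but with bf16 instead of fp16
training_args = DPOConfig(
    dataloader_pin_memory=False,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    logging_steps=10,
    bf16=True,  # bf16 shares fp32's dynamic range, so it overflows far less than fp16
)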


Could you check the dtype of the LoRA parameters after model initialization? Specifically, are they float16 or float32?
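For a quick check, something like this should print them (a sketch that assumes the standard PEFT naming convention, where adapter weights contain "lora_" in their parameter names):

# Inspect the dtype of the LoRA adapter parameters after the trainer has wrapped the model
for name, param in trainer.model.named_parameters():
    if "lora_" in name:
        print(name, param.dtype, param.requires_grad)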
