Finetune Llama with PPOTrainer

Solution found! In my case, setting torch_dtype to torch.bfloat16 for both the base model and the PPO model solved the problem.

For example:

import torch
from transformers import AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead

# Load the base model in bfloat16 (config and device_map are defined earlier)
model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    torch_dtype=torch.bfloat16,
    device_map=device_map,
)

# Wrap it with a value head in the same dtype for PPO training
ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model,
    torch_dtype=torch.bfloat16,
    is_trainable=True,
)
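For completeness, here is a minimal sketch of passing the wrapped model to PPOTrainer, assuming the legacy TRL PPO API (pre-0.12) where config is a PPOConfig; the tokenizer setup and the ref_model=None choice are illustrative additions, not from the original post:

from transformers import AutoTokenizer
from trl import PPOTrainer

# Tokenizer for the same checkpoint; PPO generation needs a pad token
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token

# With ref_model=None, TRL builds a frozen reference copy internally
ppo_trainer = PPOTrainer(
    config=config,
    model=ppo_model,
    ref_model=None,
    tokenizer=tokenizer,
)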