Solution found! In my case, setting `torch_dtype=torch.bfloat16` for both the base model and the PPO model solved the problem. For example:
```python
import torch
from transformers import AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead

# Load the base model in bfloat16.
model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    torch_dtype=torch.bfloat16,
    device_map=device_map,
)

# Wrap it with a value head for PPO, keeping the same dtype.
ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model,
    torch_dtype=torch.bfloat16,
    is_trainable=True,
)
```