Finetune Llama with PPOTrainer

I am trying to fine-tune Llama with the PPOTrainer class from TRL, following a similar tutorial that fine-tunes GPT-2 on the IMDB dataset.
But I keep getting this error when logging to wandb: ValueError: autodetected range of [nan, nan] is not finite

Also, many PPO-related values such as 'ppo/loss/policy', 'ppo/loss/value', 'ppo/loss/total', 'ppo/policy/entropy', etc. are NaN.
See this notebook (a copy of the tutorial notebook, but with a different model) to reproduce the error.
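
To see which statistics have gone non-finite before wandb raises, a small check along these lines might help. This is only a minimal sketch: `non_finite_keys` is a hypothetical helper (not part of TRL or wandb), and the sample dictionary is made up for illustration, reusing the key names from the post.

```python
import math
from collections.abc import Iterable


def non_finite_keys(stats):
    """Return the keys of `stats` whose values contain NaN or infinity.

    Handles plain numbers and flat sequences of numbers; anything that
    cannot be converted to float (strings, nested arrays) is skipped.
    """
    bad = []
    for key, value in stats.items():
        if isinstance(value, str):
            continue
        values = value if isinstance(value, Iterable) else [value]
        try:
            if any(not math.isfinite(float(v)) for v in values):
                bad.append(key)
        except (TypeError, ValueError):
            # Not a number or a flat sequence of numbers: ignore it.
            continue
    return bad


# Made-up values, only to show the shape of the check.
sample_stats = {
    "ppo/loss/policy": float("nan"),
    "ppo/loss/value": 0.12,
    "ppo/loss/total": float("nan"),
    "ppo/policy/entropy": [1.5, float("inf")],
}
bad_keys = non_finite_keys(sample_stats)
print(bad_keys)
```

Running a check like this on the statistics from each PPO step, before they are logged, would show the first step at which NaNs appear instead of failing later inside the wandb histogram code.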

Hi Harshvir! I am running into exactly the same situation while testing with a small gpt-neo-x model. Did you already solve this problem? I would appreciate it very much if you could share the solution. Thanks!

Solution found! In my case, setting torch_dtype to torch.bfloat16 for both the base model and the PPO model solved the problem.

For example:

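import torch
from transformers import AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead

# config.model_name and device_map are assumed to be defined earlier
# (they are not shown in this post).
# Load the base model in bfloat16.
model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    torch_dtype=torch.bfloat16,
    device_map=device_map,
)

# Wrap the base model with a value head for PPO, keeping the same dtype.
ppo_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model,
    torch_dtype=torch.bfloat16,
    is_trainable=True,
)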
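
The thread does not say why bfloat16 helps. One plausible explanation, which is an assumption here and not something confirmed above, is numeric range: float16 overflows to infinity a little above 65504, and arithmetic on infinities can produce NaN, while bfloat16 covers roughly the same range as float32. The sketch below illustrates only that range difference using the standard library; `to_float16` and `to_bfloat16` are hypothetical helpers, and the bfloat16 conversion is approximated by truncating a float32 to its top 16 bits.

```python
import math
import struct


def to_float16(x):
    """Round x to IEEE half precision; return +/-inf if it is out of range."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        # struct refuses to pack values beyond the float16 range.
        return math.copysign(math.inf, x)


def to_bfloat16(x):
    """Approximate bfloat16 by keeping the top 16 bits of the float32 encoding."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]


big = 70000.0  # an arbitrary value just above the float16 maximum
half = to_float16(big)
brain = to_bfloat16(big)

print("float16:", half, "| finite:", math.isfinite(half))
print("bfloat16:", brain, "| finite:", math.isfinite(brain))
# Subtracting infinity from itself is one way a NaN can appear.
print("inf - inf is NaN:", math.isnan(half - half))
```

If this is what is happening, the trade-off is that bfloat16 keeps fewer mantissa bits than float16, so it gives up precision in exchange for range.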