PPOTrainer: Output generated during training differs from output during inference

Hello,

I’m using TRL’s PPOTrainer to fine-tune a language model and I’m having trouble debugging an issue where the trained model generates gibberish. At first I thought this was due to the training setup, but the outputs generated every few steps with PPOTrainer.generate() during training look fine. However, when I load the model trained for the same number of steps and run offline inference on the same set of input prompts, I somehow end up with gibberish. I’m guessing there is some subtle difference between how generation is done during training and at inference that I’m currently missing.
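For what it’s worth, one check I can think of for that kind of subtle difference is making sure the tokenizer I use offline matches the one from training (a minimal sketch; both paths are placeholders):

from transformers import AutoTokenizer

# Placeholders: the base model used for training and the saved PPO checkpoint.
train_tokenizer = AutoTokenizer.from_pretrained("base-model-name")
infer_tokenizer = AutoTokenizer.from_pretrained("path/to/ppo_checkpoint")

# A mismatch in any of these would change how prompts are fed to generate().
for attr in ("padding_side", "pad_token", "eos_token_id", "bos_token_id"):
    print(attr, getattr(train_tokenizer, attr), getattr(infer_tokenizer, attr))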

For the training code, I mostly followed the StackLLaMa example. In case it’s helpful, here is a snippet:

model_kwargs = dict(
    revision=model_args.model_revision,
    trust_remote_code=model_args.trust_remote_code,
    use_flash_attention_2=model_args.use_flash_attention_2,
    torch_dtype=torch_dtype,
    use_cache=False if training_args.gradient_checkpointing else True,
    device_map=get_kbit_device_map(),
    quantization_config=get_quantization_config(model_args),
)

model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model_args.model_name_or_path,
    load_in_8bit=True,
    peft_config=get_peft_config(model_args),
    **model_kwargs,
)
...
ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    ref_model=None,
    tokenizer=tokenizer,
    dataset=train_dataset,
    data_collator=collator,
)
...

generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "max_new_tokens": training_args.output_max_length,
    "pad_token_id": tokenizer.eos_token_id,
}
...
for step, batch in tqdm(
    enumerate(ppo_trainer.dataloader),
    desc="Train step: ",
    disable=not accelerator.is_local_main_process,
    total=min(len(ppo_trainer.dataloader), ppo_config.steps),
):
    if step >= ppo_config.steps: break

    # Run a small eval every few steps as a sanity check.
    if step % 5 == 0:
        small_eval(step, ppo_trainer, rmodel, eval_dataloader)

    query_tensors = batch["input_ids"]

    # Sample responses from the policy model.
    response_tensors = ppo_trainer.generate(
        query_tensors,
        return_prompt=False,
        **generation_kwargs,
    )
    batch["response"] = tokenizer.batch_decode(response_tensors)

    # Compute the reward scores.
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    with torch.no_grad():
        rmodel_inputs = rmodel_tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        rewards = [r.squeeze().to(dtype=torch.float) for r in rmodel(**rmodel_inputs)]

    # Run PPO step.
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)
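
As an extra sanity check, I’m also thinking of comparing PPOTrainer.generate() against a direct generate() on the underlying model for a single query, to rule out the wrapper itself (batched padding, prompt stripping) as the source of the difference. A rough sketch; since do_sample=True the two outputs won’t match token for token, I’d only compare fluent vs. gibberish:

# (This goes inside the loop above, run occasionally.) Generate one query
# directly with the policy's own generate(), using the same kwargs that
# PPOTrainer.generate() receives, then compare against the wrapped output.
query = query_tensors[0]
with torch.no_grad():
    direct_out = model.generate(
        query.unsqueeze(0).to(model.pretrained_model.device),
        **generation_kwargs,
    )
print("via PPOTrainer.generate:", batch["response"][0])
# Strip the prompt tokens before decoding, to mirror return_prompt=False.
print("via model.generate:", tokenizer.decode(direct_out[0][query.shape[0]:], skip_special_tokens=True))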

The offline inference code is essentially the few lines given in the “Use model after training” section of the docs.
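
Concretely, the inference side looks roughly like this (a sketch; the checkpoint path and prompt are placeholders, and I reuse the same sampling settings as in training):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint_dir = "path/to/ppo_checkpoint"  # placeholder

# Depending on how the checkpoint was saved (full weights vs. only a LoRA
# adapter), this line may need peft.AutoPeftModelForCausalLM instead.
model = AutoModelForCausalLM.from_pretrained(
    checkpoint_dir, torch_dtype=torch.float16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)

prompt = "one of the training prompts"  # placeholder
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Same sampling settings as in generation_kwargs above.
with torch.no_grad():
    output = model.generate(
        **inputs,
        do_sample=True,
        top_k=0,
        top_p=1.0,
        max_new_tokens=128,
        pad_token_id=tokenizer.eos_token_id,
    )
print(tokenizer.decode(output[0], skip_special_tokens=True))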

I’m curious whether anyone with more experience with PPOTrainer has run into a similar issue. I’d very much appreciate any input on this.

Also, on a somewhat related note, I’m curious about the correct way to save checkpoints and resume training with PPOTrainer. I’ve noticed that its API differs somewhat from SFTTrainer’s, which relies on the transformers Trainer for checkpointing in the conventional way.
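
For context, what I’m doing right now is just calling save_pretrained periodically inside the training loop, along the lines of the StackLLaMa script (a sketch; save_freq is a placeholder). As far as I can tell this writes the model/adapter and tokenizer but no optimizer state, which is why I’m unsure how resuming is supposed to work:

save_freq = 100  # placeholder

# Inside the training loop, right after ppo_trainer.step(...):
if save_freq and step and step % save_freq == 0:
    ppo_trainer.save_pretrained(f"{training_args.output_dir}/step_{step}")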

Thanks!

I’m trying this out for the first time and running into the same issue. During DPO training, my loss decreased and the rewards slowly increased to ~90%. When I load the model and pass in the same prompts, it outputs gibberish. Not sure what went wrong 🙂 Following.