Unstable PPO training: Highly negative KL divergence and highly positive average ratio of batch on LLMs

I have an APS (Answer Paragraph Selection) model that takes question_text + paragraph_text as input and predicts a value between 0 and 1 (logits → sigmoid) representing whether the given paragraph can answer the question.

This APS model has been finetuned separately on a large dataset.

I was wondering: what if I use this APS model as the reward model in PPO training? I removed the sigmoid from the APS model to get the raw logits. I implemented this by building upon this.

The prompt is something like, "Answer the question based on the context.\n\n##Q: <question>\n\n##C: <context>\n\n#A:"
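Concretely, the prompt is assembled with something like the following (the helper name is just for illustration):

```python
def build_prompt(question: str, context: str) -> str:
    # Template described above: instruction, question, context, then the answer cue.
    return (
        "Answer the question based on the context.\n\n"
        f"##Q: {question}\n\n"
        f"##C: {context}\n\n"
        "#A:"
    )
```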

Only the generated response and the original question are fed to the APS model, and the logits are captured as the reward.
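Roughly, the reward step looks like this (a simplified sketch; `aps_model`, `aps_tokenizer`, and the variable names are placeholders for my actual objects):

```python
import torch

def compute_rewards(questions, responses, aps_model, aps_tokenizer, device):
    # Score each (question, generated response) pair with the APS classifier.
    # The sigmoid head is skipped, so the raw logit is used as the reward.
    rewards = []
    for question, response in zip(questions, responses):
        inputs = aps_tokenizer(
            question,
            response,
            truncation=True,
            return_tensors="pt",
        ).to(device)
        with torch.no_grad():
            logit = aps_model(**inputs).logits.squeeze()
        rewards.append(logit.float().cpu())
    return rewards

# Inside the PPO loop (TRL's PPOTrainer expects a list of scalar reward tensors):
# response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs)
# responses = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)
# rewards = compute_rewards(questions, responses, aps_model, aps_tokenizer, device)
# stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
```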

But while running the pipeline, I observed that long contexts easily result in CUDA OOM (even on 4 A100s with gemma-2b). When I use a shorter context (filtered through another APS model), I get

/usr/l.../ppo_trainer.py:1246: UserWarning: The average ratio of batch (2373949488889856.00) exceeds threshold 10.00. Skipping batch.

/usr/l...er/ppo_trainer.py:1313: UserWarning: KL divergence is starting to become negative: -5529.82 - this might be a precursor for failed training. sometimes this happens because the generation kwargs are not correctly set. Please make sure that the generation kwargs are set correctly, or review your training hyperparameters.

Then I get

ValueError: autodetected range of [nan, nan] is not finite

while logging the stats.

I have tried changing the data type of the PPO model, setting a minimum length for the output, explicitly setting the EOS token ID, and changing the gradient accumulation steps.
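For reference, the generation kwargs I pass to `ppo_trainer.generate()` look roughly like this (the values are illustrative, not my exact config):

```python
from transformers import AutoTokenizer

# Policy model tokenizer (gemma-2b in my case); used for the pad/eos ids below.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")

generation_kwargs = {
    "min_length": -1,        # also tried small positive values here
    "max_new_tokens": 64,
    "do_sample": True,
    "top_k": 0,
    "top_p": 1.0,
    "pad_token_id": tokenizer.eos_token_id,
    "eos_token_id": tokenizer.eos_token_id,  # explicitly set, as mentioned above
}
```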

But nothing has worked so far. Even if I disable logging to bypass the error, I get

RuntimeError: probability tensor contains either `inf`, `nan` or element < 0 while generating the text.

How can I solve this?

I have tried a lower learning rate as well; that just delays the inevitable.
