PPO using TRL: optimal strategy for reward calculation?

Hi everyone,

Something I’ve been wondering about recently; I’d value some input.

I’ve been working with the trl library to fine-tune various decoder-only LLMs via RLHF. During the PPO loop, I’ll collect the rewards using something like:

raw_rewards = ppo_trainer.model.compute_reward_score(input_ids, attention_masks)

This returns the class logits from the previously trained AutoModelForSequenceClassification model, and we then need to turn these into a scalar reward. I’ve seen different approaches to this, for example taking the first element of the logits (see here) or taking the last element (see here).
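As a rough illustration of the two conventions, here is a minimal sketch that reduces a score tensor to one scalar per sequence. It assumes the scores have shape (batch, seq_len, num_labels), i.e. a classification head applied at every position, that sequences are right-padded, and that `label_index=0` or `label_index=-1` correspond to the "first element" and "last element" approaches. `select_reward` is a hypothetical helper, not part of trl. If you call the classification model directly, its logits are typically already pooled to shape (batch, num_labels), and only the label selection step is needed.

```python
import torch


def select_reward(scores: torch.Tensor, attention_mask: torch.Tensor, label_index: int = -1) -> torch.Tensor:
    """Reduce per-token class logits to one scalar reward per sequence (hypothetical helper).

    scores:         (batch, seq_len, num_labels) logits from the reward head (assumed shape)
    attention_mask: (batch, seq_len), 1 for real tokens, 0 for right padding (assumed)
    label_index:    which logit to use, e.g. 0 ("first element") or -1 ("last element")
    """
    # Index of the last real token in each right-padded sequence.
    last_pos = attention_mask.sum(dim=1) - 1
    batch_idx = torch.arange(scores.size(0), device=scores.device)
    # Logits at the final real token: shape (batch, num_labels).
    final_logits = scores[batch_idx, last_pos]
    return final_logits[:, label_index]


scores = torch.tensor(
    [
        [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
        [[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]],  # last position of this row is padding
    ]
)
mask = torch.tensor([[1, 1, 1], [1, 1, 0]])

print(select_reward(scores, mask, label_index=0).tolist())  # [0.5, 3.0]
print(select_reward(scores, mask, label_index=-1))          # approximately [0.6, 4.0]
```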

Examining the code for the RewardTrainer, we can see how the loss function is constructed (line 238):

loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()

For a single chosen/rejected pair, if we denote the two logits of the chosen and rejected rewards by the tuples (c1, c2) and (r1, r2), the mean runs over both elements and we can write the above as:

\begin{align} \mathcal{L} &= -\frac{1}{2} \left[ \log \sigma \left(c_1-r_1\right) + \log \sigma \left(c_2 - r_2\right) \right] \\ &= \frac{1}{2} \left[ \log \left(e^{-c_1+r_1} + 1\right) + \log \left(e^{-c_2+r_2} + 1\right) \right] \end{align}

From this we can see that the loss pushes each element of the chosen reward above the corresponding element of the rejected reward. Since only the differences c1 - r1 and c2 - r2 appear, both elements are trained to score a chosen input higher than a rejected one, rather than towards any absolute value. (Hence, presumably, why I have seen code that uses either element.)
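One consequence of the loss depending only on differences is that each label's offset is arbitrary: adding a constant to a given element of both the chosen and the rejected reward leaves the loss unchanged. The sketch below checks this with made-up logits and a per-label offset; `pairwise_loss` is a hypothetical wrapper around the same expression as the RewardTrainer line quoted above. It is worth keeping in mind when comparing the raw magnitudes of the two elements.

```python
import torch
import torch.nn.functional as F


def pairwise_loss(rewards_chosen: torch.Tensor, rewards_rejected: torch.Tensor) -> torch.Tensor:
    # Same form as the RewardTrainer loss: mean over batch and over both logits.
    return -F.logsigmoid(rewards_chosen - rewards_rejected).mean()


# Made-up logits for a batch of two pairs, two labels each.
chosen = torch.tensor([[1.5, -0.2], [0.3, 2.0]])
rejected = torch.tensor([[0.5, -1.0], [0.1, 0.4]])

# A different offset for each label (column), applied to chosen and rejected alike.
offset = torch.tensor([10.0, -7.0])

base = pairwise_loss(chosen, rejected)
shifted = pairwise_loss(chosen + offset, rejected + offset)
print(torch.allclose(base, shifted))  # True: only the margins are constrained
```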

So, here’s the first question: should we be indifferent as to which element of the reward logits we hand to the PPO algorithm, or would it be better still to combine them (e.g. by summing them)?
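One way to frame this: if both labels are trained by the same pairwise loss, the chosen-minus-rejected margin of the sum is exactly the sum of the two per-label margins. So when the two labels largely agree, summing mainly changes the scale (and offset) of the reward rather than its ordering. The sketch below uses made-up logits and a hypothetical `combine_labels` helper to show this. Whether the larger scale helps or hurts likely depends on how the score is normalised and weighed against the KL penalty in your PPO setup, since with a fixed KL coefficient a doubled score shifts that balance.

```python
import torch


def combine_labels(final_logits: torch.Tensor, mode: str) -> torch.Tensor:
    """Turn (batch, 2) final-token logits into one reward per sample (hypothetical helper)."""
    if mode == "first":
        return final_logits[:, 0]
    if mode == "last":
        return final_logits[:, -1]
    if mode == "sum":
        return final_logits.sum(dim=-1)
    raise ValueError(f"unknown mode: {mode}")


# Made-up logits for one chosen and one rejected response (two labels each).
chosen = torch.tensor([[1.5, 0.8]])
rejected = torch.tensor([[0.5, -0.4]])

for mode in ("first", "last", "sum"):
    margin = combine_labels(chosen, mode) - combine_labels(rejected, mode)
    print(mode, round(margin.item(), 4))
# Prints first 1.0, last 1.2, sum 2.2: the sum's margin is the sum of the
# per-label margins, so it roughly doubles the reward scale PPO sees.
```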

And here’s the second question: the PPOTrainer lets us apply batch scaling and reward clipping to help stabilise gradient updates for the policy model, but why not simply pass the reward logits through a logsigmoid() function first? That, after all, would mirror what the loss function in the RewardTrainer is doing.
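To see what a logsigmoid transform would do compared with batch scaling and clipping, the sketch below applies both to the same made-up scores. logsigmoid is monotonic, so it preserves the ordering, but it is asymmetric: large positive scores all saturate just below 0, while large negative scores pass through almost unchanged (log σ(x) ≈ x for x much less than 0). Here the gap between the two best responses (2 and 8) shrinks to about 0.13, while the worst stays near -6, so it would largely stop PPO distinguishing good from very good responses without bounding how negative a reward can get. `normalise_and_clip` is a hypothetical stand-in for batch scaling plus clipping, not trl's exact implementation.

```python
import torch
import torch.nn.functional as F

# Made-up raw scores for a batch of five responses.
raw = torch.tensor([-6.0, -1.0, 0.0, 2.0, 8.0])

# Option A: squash with logsigmoid, mirroring the reward-model loss.
squashed = F.logsigmoid(raw)


# Option B: standardise within the batch, then clip (hypothetical stand-in,
# not trl's implementation of score scaling/clipping).
def normalise_and_clip(scores: torch.Tensor, clip: float = 1.0) -> torch.Tensor:
    z = (scores - scores.mean()) / (scores.std() + 1e-8)
    return z.clamp(-clip, clip)


scaled = normalise_and_clip(raw)
for r, a, b in zip(raw.tolist(), squashed.tolist(), scaled.tolist()):
    print(f"raw={r:5.1f}  logsigmoid={a:7.3f}  norm+clip={b:6.3f}")
```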

Are there any theoretical reasons or implementation details that bear upon the above?

Since I can't edit the post, for clarity: the second line of the equation above should read

\frac{1}{2} \left[ \log \left(e^{-c_1+r_1} + 1\right) + \log \left(e^{-c_2+r_2} + 1\right) \right]
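As a sanity check on this expression, the sketch below evaluates the RewardTrainer-style loss for a single chosen/rejected pair with two logits each and compares it with the expanded form, using -log σ(x) = log(1 + e^{-x}). The logit values are arbitrary and chosen only for illustration.

```python
import math

import torch
import torch.nn.functional as F

# One chosen/rejected pair with two logits each: (c1, c2) and (r1, r2). Arbitrary values.
c1, c2 = 0.7, -1.3
r1, r2 = -0.4, 0.9

rewards_chosen = torch.tensor([[c1, c2]], dtype=torch.float64)
rewards_rejected = torch.tensor([[r1, r2]], dtype=torch.float64)

# Loss as written in the RewardTrainer (mean over both logits).
loss = -F.logsigmoid(rewards_chosen - rewards_rejected).mean()

# Expanded form: (1/2) [log(e^{-c1+r1} + 1) + log(e^{-c2+r2} + 1)]
expanded = 0.5 * (math.log(math.exp(-c1 + r1) + 1) + math.log(math.exp(-c2 + r2) + 1))

print(abs(loss.item() - expanded) < 1e-12)  # True
```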