How to pass input to a Reward Model and make sense of its output?

I am in the process of doing RLHF on LLaMA 2 13B. One of the steps is training a reward model.
I built a custom dataset of preferred and less-preferred texts. It is very similar to the example in the official TRL library, with “chosen” and “rejected” fields (see Reward Modeling).

The reward model trained successfully (the eval accuracy in the logs was about 67%, but that’s a story for a different day).

Now what I would like to do is to actually pass an input and see the output of the Reward model.

However I can’t seem to make any sense of what the reward model outputs.

For example: I tried to make the input as follows -

chosen = "This is the chosen text."
rejected = "This is the rejected text."
test = {"chosen": chosen, "rejected": rejected}

Then I try -

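import torch
from torch import nn

# model and tokenizer are the trained reward model and its tokenizer, loaded beforehand
with torch.no_grad():  # inference only, no gradients needed
    rewards_chosen = model(**tokenizer(chosen, return_tensors='pt')).logits
    rewards_rejected = model(**tokenizer(rejected, return_tensors='pt')).logits

print('reward chosen is ', rewards_chosen)
print('reward rejected is ', rewards_rejected)

# pairwise loss: small when chosen scores well above rejected
loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()
print(loss)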

Printing the loss wasn’t helpful: I do not see any trend, even if I swap rewards_chosen and rewards_rejected in the formula.
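For reference, here is a minimal sketch of how the formula above behaves when each text gets exactly one scalar reward. The reward values are made up for illustration and are not outputs of the model in this post, and the helper names `pairwise_loss` and `prob_chosen_preferred` are hypothetical. It uses plain `math` rather than torch so the arithmetic is easy to follow.

```python
import math

def pairwise_loss(reward_chosen: float, reward_rejected: float) -> float:
    """-log(sigmoid(r_chosen - r_rejected)) for a single pair of scalar rewards."""
    margin = reward_chosen - reward_rejected
    # log1p(exp(-margin)) is algebraically equal to -log(sigmoid(margin))
    return math.log1p(math.exp(-margin))

def prob_chosen_preferred(reward_chosen: float, reward_rejected: float) -> float:
    """sigmoid of the reward margin, read as P(chosen is preferred)."""
    return 1.0 / (1.0 + math.exp(-(reward_chosen - reward_rejected)))

# Hypothetical (chosen, rejected) reward pairs: agree, disagree, tie
pairs = [(2.0, 0.5), (0.5, 2.0), (1.0, 1.0)]
for r_c, r_r in pairs:
    print(
        f"margin={r_c - r_r:+.2f}  "
        f"loss={pairwise_loss(r_c, r_r):.4f}  "
        f"p(chosen)={prob_chosen_preferred(r_c, r_r):.4f}"
    )
```

Under these assumptions the loss only depends on the margin: it equals log 2 (about 0.693) when the two rewards tie, drops below that when the chosen text scores higher, and rises above it when the order is swapped. On a single pair the absolute reward values carry little meaning; only the difference between the two does.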

The outputs themselves did not yield much insight either. I do not understand how to interpret rewards_chosen and rewards_rejected. In some examples rewards_chosen is larger, and in others it is smaller (shouldn’t it always be higher?).

I tried rewards_chosen > rewards_rejected, but that is not helpful either, since it outputs tensor([[ True, False]]).
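A two-element result like that suggests the logits have shape (1, 2), i.e. the classification head has two outputs instead of one. One possible cause, which is an assumption since the loading code is not shown here, is that the model was created with `AutoModelForSequenceClassification` without `num_labels=1`; the default is two labels, whereas the TRL reward modeling example uses a single label so that each text gets one scalar score. The sketch below uses hand-written tensors (hypothetical values, no real model) to show the difference; `margin_and_prob` is an illustrative helper, not a TRL function.

```python
import torch

# Hypothetical logits shaped like a two-label head's output: (batch=1, labels=2)
two_label_chosen = torch.tensor([[0.3, -1.2]])
two_label_rejected = torch.tensor([[0.1, -0.4]])
# Elementwise comparison gives one boolean per label, not one answer per pair
print(two_label_chosen > two_label_rejected)

# Hypothetical logits shaped like a single-label head's output: (batch=1, labels=1)
reward_chosen = torch.tensor([[1.7]])
reward_rejected = torch.tensor([[0.4]])

def margin_and_prob(r_chosen: torch.Tensor, r_rejected: torch.Tensor):
    """Return (reward margin, sigmoid(margin)) for one pair of scalar rewards."""
    if r_chosen.numel() != 1 or r_rejected.numel() != 1:
        raise ValueError("expected exactly one scalar reward per text")
    margin = (r_chosen - r_rejected).item()
    prob = torch.sigmoid(torch.tensor(margin)).item()
    return margin, prob

margin, prob = margin_and_prob(reward_chosen, reward_rejected)
print(f"margin={margin:.3f}, p(chosen preferred)={prob:.3f}")
```

With one scalar per text, a positive margin means the model prefers the first string and a negative margin means it prefers the second. Checking `model.config.num_labels` and the shape of `.logits` would confirm whether this applies to the model in question.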

How do we figure out what the output of the reward model means, and how do we know which string it prefers?

Did you ever find the answer? For some reason, I am stuck with the 67% accuracy problem too (after extensive LoRA hyperparameter tuning), and maybe that is the reason the outputs show no pattern.