Confusing (and possibly misleading) PPO Trainer Code from TRL API Doc Tutorial

The name of the loop variable `epoch` in the TRL API documentation is confusing and possibly misleading.

From this website:

It uses this code as an example of how to use `PPOTrainer`:

```python
import torch
from tqdm import tqdm

for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]

    #### Get response from SFTModel
    response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs)
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

    #### Compute reward score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = reward_model(texts)
    rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]

    #### Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

#### Save model
ppo_trainer.save_model("my_ppo_model")
```

From this code one would infer that each item yielded by the dataloader corresponds to an epoch, i.e. that enumerating the dataloader gives the epoch number (1st epoch, 2nd epoch, …). That would mean `len(dataloader)` equals the number of epochs, which does not seem to be the case: the dataloader should correspond to a single epoch (one pass over the dataset), and the index returned by `enumerate` is the batch index within that pass.
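A minimal pure-Python sketch of what `enumerate` actually yields, assuming a simple batching helper in place of the real `ppo_trainer.dataloader` (`make_batches` and the 1000-sample dataset are hypothetical). The first value of each pair is the position of the batch within a single pass, and the number of batches is ⌈N / batch_size⌉, not the number of epochs.

```python
import math

def make_batches(dataset, batch_size):
    """Hypothetical stand-in for a DataLoader: split the dataset into consecutive batches.
    The last batch may be smaller than batch_size."""
    return [dataset[i:i + batch_size] for i in range(0, len(dataset), batch_size)]

dataset = list(range(1000))  # 1000 hypothetical samples
batch_size = 256             # the PPOConfig default mentioned in this thread

batches = make_batches(dataset, batch_size)
indices = [idx for idx, _ in enumerate(batches)]

print(len(batches))  # 4, i.e. math.ceil(1000 / 256)
print(indices)       # [0, 1, 2, 3]: batch positions within ONE pass, not epoch numbers
```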

The PPO config used in the doc here is:

```python
from trl import PPOConfig

config = PPOConfig(
    model_name="gpt2",
    learning_rate=1.41e-5,
)
```

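From this we infer that `batch_size` is 256 (the default value) and that the number of epochs is the default of 4 (`ppo_epochs`). Note that `ppo_epochs` counts the optimization passes over each batch inside `ppo_trainer.step`, not passes over the whole dataset.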
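A rough counting sketch of how the two notions of "epoch" combine, assuming that `ppo_epochs` is consumed inside each `ppo_trainer.step` call (as TRL's `PPOConfig` describes it: optimisation epochs per batch of samples) and that each batch is split into minibatches. The function `count_updates`, the dataset size of 1024 and `mini_batch_size=64` are hypothetical values chosen for illustration; this is arithmetic, not TRL code.

```python
import math

def count_updates(dataset_size, batch_size, ppo_epochs, mini_batch_size, dataset_passes):
    """Hypothetical counter: number of ppo_trainer.step() calls and minibatch updates in a run."""
    steps_per_pass = math.ceil(dataset_size / batch_size)  # iterations of `for batch in dataloader`
    total_steps = dataset_passes * steps_per_pass          # outer `for epoch in range(...)` loop
    # Inside each step(): ppo_epochs passes over the batch, one update per minibatch.
    updates_per_step = ppo_epochs * math.ceil(batch_size / mini_batch_size)
    return total_steps, total_steps * updates_per_step

steps, updates = count_updates(dataset_size=1024, batch_size=256, ppo_epochs=4,
                               mini_batch_size=64, dataset_passes=1)
print(steps, updates)  # 4 64
```

With a single pass over the data, the loop body runs 4 times even though `ppo_epochs` is also 4; the two numbers only coincide here by accident of the chosen dataset size.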

In the TRL repo's PPO example script, by contrast, the dataloader corresponds to the data for one epoch:

```python
for epoch in range(2):
    for batch in tqdm(ppo_trainer.dataloader):
        [...]
```

This is really confusing: in one case we need an outer loop over epochs, while in the other case a single loop over the dataloader is used, as if the dataloader covered the data for ALL epochs.

The example from the repo (with the outer loop) seems to be the correct one, which makes the example in the docs confusing and misleading. Am I missing something, or is `for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):` misleading?

Perhaps a better name for it would be `batch_id`, i.e. the index of the i-th batch in the dataloader within one epoch.
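A sketch of the loop structure used in the repo example, with the inner index renamed as suggested (here `batch_idx`). `fake_dataloader` and `fake_step` are hypothetical stand-ins for `ppo_trainer.dataloader` and `ppo_trainer.step`, so the structure can be run without TRL or a model.

```python
num_epochs = 2  # passes over the whole dataset (outer loop)
fake_dataloader = [["q0", "q1"], ["q2", "q3"], ["q4"]]  # hypothetical: 3 batches = one pass

log = []

def fake_step(epoch, batch_idx, batch):
    """Hypothetical stand-in for ppo_trainer.step(); only records which batch it saw."""
    log.append((epoch, batch_idx, len(batch)))

for epoch in range(num_epochs):                           # epoch = pass over the dataset
    for batch_idx, batch in enumerate(fake_dataloader):   # batch_idx = position within this pass
        fake_step(epoch, batch_idx, batch)

print(len(log))  # 6, i.e. num_epochs * len(fake_dataloader)
print(log[:3])   # [(0, 0, 2), (0, 1, 2), (0, 2, 1)]
```

Naming the inner index `batch_idx` makes it clear that it resets to 0 at the start of every epoch, which is exactly what the doc's `epoch` variable does.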

Hi,

Thanks for flagging. Feel free to open an issue/pull request on the TRL repository regarding this.