OOM Error using PPO Trainer to LoRA-tune 4-bit Llama-3-8B Model

As is standard for PPO training (supervised fine-tuning first, then running the PPO algorithm), I did a QLoRA fine-tune of the Llama-3-8B Instruct model on my own custom data using the SFT Trainer. I then merged the LoRA adapters and pushed the merged model to the HF Hub in 4-bit.

For the PPO training step, I initialized my model like this (the LoRA config and the quantization config are defined earlier in the notebook):

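```python
from trl import AutoModelForCausalLMWithValueHead

model_id = "path-to-my-model"
model = AutoModelForCausalLMWithValueHead.from_pretrained(
    model_id,
    peft_config=lora_config,         # LoRA config defined in an earlier cell
    device_map={"": 0},              # put the whole model on GPU 0
    quantization_config=bnb_config,  # 4-bit config defined in an earlier cell
)
```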

Then I run my PPO training loop (using a custom PyTorch dataloader, because the PPOTrainer's built-in dataloader does not support dynamic padding when streaming a large dataset):

```python
from tqdm import tqdm

# Training parameters
epochs = 4
generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.pad_token_id,
    "max_new_tokens": 2560,  # max input is 2048, so leave room for longer responses
}

# Training loop
for epoch in tqdm(range(epochs), "epoch: "):
    batchnum = 0
    numsave = 1
    for batch in tqdm(dataloader):  # switched from ppo_trainer.dataloader to the dataloader defined in the cell above
        batchnum += 1
        query_tensors = batch["input_ids"]  # "input_ids" is just the tokenized query (list of integers)

        # Get responses from the SFT model
        response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs)
        batch["response"] = [tokenizer.decode(r.squeeze(), skip_special_tokens=True) for r in response_tensors]

        # Compute rewards (batch["query"] is the raw text and was never encoded, so use it directly)
        rewards = [reward_function(query, response) for query, response in zip(batch["query"], batch["response"])]

        # Run PPO step
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
        ppo_trainer.log_stats(stats, batch, rewards)

        # Save a checkpoint every 25,000 batches
        if batchnum == 25000:
            ppo_trainer.save_model(f"/content/drive/MyDrive/Worksheets AI/PPO_Llama_epoch_{epoch}_{numsave}")
            numsave += 1
            batchnum = 0

    # Save the trained model after each epoch
    ppo_trainer.save_pretrained(f"/content/drive/MyDrive/Worksheets AI/PPO_Llama_epoch_{epoch}")
```
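The custom dataloader lives in another cell of the notebook. As a minimal sketch of what "dynamic padding" means here (padding each batch only to its own longest query instead of a global maximum), a collate function could look like the following. The function name `dynamic_pad_collate` and the item layout (`"query"` text plus `"input_ids"` list) are hypothetical assumptions, not the notebook's actual code.

```python
from typing import Any, Dict, List


def dynamic_pad_collate(items: List[Dict[str, Any]], pad_token_id: int,
                        left_pad: bool = True) -> Dict[str, List[Any]]:
    """Hypothetical collate_fn: pad input_ids to the longest sequence in THIS batch.

    Assumes each item is {"query": str, "input_ids": list[int]}.
    """
    max_len = max(len(item["input_ids"]) for item in items)
    input_ids, attention_mask = [], []
    for item in items:
        ids = item["input_ids"]
        pad = [pad_token_id] * (max_len - len(ids))
        real = [1] * len(ids)        # mask 1 for real tokens
        masked = [0] * len(pad)      # mask 0 for padding
        if left_pad:  # decoder-only generation usually pads on the left
            input_ids.append(pad + ids)
            attention_mask.append(masked + real)
        else:
            input_ids.append(ids + pad)
            attention_mask.append(real + masked)
    return {
        "query": [item["query"] for item in items],  # raw text, used by the reward
        "input_ids": input_ids,
        "attention_mask": attention_mask,
    }
```

With PyTorch, this would be passed as `collate_fn=functools.partial(dynamic_pad_collate, pad_token_id=tokenizer.pad_token_id)` to `torch.utils.data.DataLoader`; the lists would still need converting to tensors before `ppo_trainer.generate`. Padding per batch keeps short batches short, which matters because memory use grows with sequence length.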

I wrote the reward model myself (`reward_function` above). It is fairly extensive and relies on a BERT model to first assign an integer rating to each response, but that BERT model runs on the CPU, and as far as I can tell the PPO algorithm does not compute gradients through the reward calculation (if I am wrong about that, I think I know where the problem is).

My question is: with the 40 GB of GPU memory that Google Colab's A100 provides, why am I still getting an OOM error? GPU memory usage stays at about 7 GB almost the whole time until the `ppo_trainer.step()` line, where it shoots up to 40 GB and throws the error.
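As a rough back-of-envelope check (a sketch, not a measurement from this notebook), the sequence lengths in `generation_kwargs` alone imply large tensors. Assuming the Llama-3 vocabulary of 128,256 tokens and float32 logits (halve the numbers for bfloat16), this estimates the size of one logits tensor over a full query + response of up to 2048 + 2560 tokens:

```python
# Back-of-envelope estimate of logits memory for one forward pass.
# Assumptions (not measured): Llama-3 vocabulary of 128,256 tokens,
# float32 logits at 4 bytes each; halve for bfloat16.

VOCAB_SIZE = 128_256
BYTES_PER_ELEMENT = 4  # float32

MAX_PROMPT_TOKENS = 2048  # "max input is 2048" from the post
MAX_NEW_TOKENS = 2560     # max_new_tokens in generation_kwargs
SEQ_LEN = MAX_PROMPT_TOKENS + MAX_NEW_TOKENS  # longest possible query + response


def logits_bytes(batch_size: int, seq_len: int = SEQ_LEN,
                 vocab_size: int = VOCAB_SIZE,
                 bytes_per_element: int = BYTES_PER_ELEMENT) -> int:
    """Size in bytes of one [batch_size, seq_len, vocab_size] logits tensor."""
    return batch_size * seq_len * vocab_size * bytes_per_element


for batch_size in (1, 4, 8):
    gib = logits_bytes(batch_size) / 2**30
    print(f"batch_size={batch_size}: {gib:.1f} GiB per logits tensor")
```

For a single 4,608-token sample this is about 2.2 GiB per logits tensor. Roughly speaking, `ppo_trainer.step()` runs forward passes of both the policy and the reference model over query + response, plus a backward pass for the policy, so several tensors of this size (on top of per-layer activations) can be alive at once and scale with the batch size.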

Also, here is the Google Colab notebook I am using, which has more details: