Invalid key for dataset — is this a bug in `Trainer` or in my code?

Code to reproduce:

import torch
from torch import nn
from transformers import AutoTokenizer, AutoModel, LlamaTokenizer, DataCollatorWithPadding
import pandas as pd
from datasets import load_dataset, Dataset, DatasetDict
from functools import partial
from transformers import TrainingArguments, Trainer
import numpy as np
import evaluate

class RewardModel(nn.Module):
    """Scalar reward head on top of a pretrained transformer backbone.

    The backbone's final hidden states are projected to one value per token
    by a single linear layer. The value at the last real (non-padding) token
    of each sequence is returned as that sequence's reward.

    NOTE(review): ``forward`` accepts only ``**args``, so ``Trainer`` cannot
    infer from its signature which dataset columns it consumes. Train with
    ``TrainingArguments(remove_unused_columns=False)``, otherwise every
    column is dropped.
    """

    def __init__(self, model):
        """Wrap *model* and add a ``hidden_size -> 1`` linear reward head.

        Args:
            model: backbone whose output exposes ``last_hidden_state`` and
                whose ``config`` has a ``hidden_size`` attribute (e.g. a
                transformers ``AutoModel``).
        """
        # nn.Module.__init__ must run before any submodule is assigned.
        # Without it the assignment below raises AttributeError, and the
        # backbone/head are never registered (so .parameters() misses them).
        super().__init__()
        self.language_model = model
        self.fc = nn.Linear(self.language_model.config.hidden_size, 1)

    def forward(self, **args):
        """Return one reward per sequence, shape ``(batch_size,)``.

        All keyword arguments are forwarded unchanged to the backbone.

        If ``attention_mask`` is supplied, the reward is read at the last
        position whose mask entry is non-zero. Right-padded batches are
        therefore scored on their final real token rather than on a pad
        token. Without a mask, the final position is used, as before.
        """
        outputs = self.language_model(**args)
        last_hidden_state = outputs.last_hidden_state  # (batch_size, seq_len, hidden)
        reward = self.fc(last_hidden_state).squeeze(-1)  # (batch_size, seq_len)

        attention_mask = args.get("attention_mask")
        if attention_mask is None:
            return reward[:, -1]  # (batch_size,)

        # Find the last non-zero mask entry per row. Flipping the mask makes
        # the first maximum (argmax returns the first) correspond to the
        # last real token. This handles both left- and right-padded batches.
        seq_len = attention_mask.size(-1)
        flipped = (attention_mask != 0).float().flip(-1)
        last_idx = seq_len - 1 - flipped.argmax(-1)
        return reward.gather(1, last_idx.unsqueeze(-1)).squeeze(-1)  # (batch_size,)

pretrained_model_name = "decapoda-research/llama-7b-hf"
model = AutoModel.from_pretrained(pretrained_model_name)
reward_model = RewardModel(model)

# Freeze every parameter, then re-enable only the linear reward head (fc),
# so training updates the head alone.
for param in reward_model.parameters():
    param.requires_grad = False
for param in reward_model.fc.parameters():
    param.requires_grad = True

tokenizer = LlamaTokenizer.from_pretrained(pretrained_model_name)
if tokenizer.pad_token is None:
    # The pasted snippet left this branch empty, which is a SyntaxError.
    # Reuse an id the embedding table already has instead of adding a new
    # token; a new token would also require model.resize_token_embeddings().
    # NOTE(review): id 0 is assumed to be LLaMA's <unk> token — confirm.
    tokenizer.pad_token_id = 0

# Already tokenized: columns 'input_ids', 'attention_mask', 'labels'.
tokenized_dataset = load_dataset('notrichardren/hh-rlhf-tf')
# Pads each batch using the tokenizer's pad token configured above.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

class LabelPopTrainer(Trainer):
    """Trainer whose loss compares the model's per-example reward to ``labels``.

    ``labels`` is removed from the batch before the model is called, because
    the reward model does not take a ``labels`` argument.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Compute a binary loss between per-example reward logits and labels.

        Args:
            model: the reward model; called with the remaining batch entries.
            inputs: batch dict; its ``"labels"`` entry is popped (mutated).
            return_outputs: if true, also return the flattened model output.
            **kwargs: extra arguments some transformers versions pass
                (e.g. ``num_items_in_batch``); accepted and ignored.

        Returns:
            ``loss``, or ``(loss, outputs)`` when *return_outputs* is true.

        NOTE(review): assumes ``labels`` are 0/1 targets, one per example —
        confirm against the dataset.
        """
        labels = inputs.pop("labels")
        outputs = model(**inputs).flatten()  # one logit per example: (batch_size,)
        # One logit per example is a binary classifier, so use BCE-with-logits.
        # cross_entropy on a 1-D tensor treats the whole batch as a single
        # sample and softmaxes across examples. Casting the target to the
        # output dtype also avoids a Half/Float mismatch from `.half()`.
        loss = torch.nn.functional.binary_cross_entropy_with_logits(
            outputs, labels.to(outputs.dtype)
        )
        return (loss, outputs) if return_outputs else loss

args = TrainingArguments(
    "test-trainer",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    logging_strategy="steps",
    logging_steps=3,
    # RewardModel.forward only takes **args, so Trainer cannot match any
    # dataset column to a forward() parameter. With the default
    # remove_unused_columns=True it drops every column, and the dataset
    # ends up with size 0 ("Invalid key ... is out of bounds for size 0").
    remove_unused_columns=False,
)

trainer = LabelPopTrainer(
    model=reward_model,
    args=args,
    train_dataset=tokenized_dataset["train"],
    data_collator=data_collator,
)
trainer.train()



I get the error “Invalid key: 222838 is out of bounds for size 0” for the dataset. However, `tokenized_dataset["train"]` (which is the one passed to the trainer) is:

    Dataset({
        features: ['labels', 'input_ids', 'attention_mask'],
        num_rows: 224104
    })

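The problem was that, in this context, `remove_unused_columns=True` (the default) deletes the entire dataset. `Trainer` keeps only the dataset columns whose names match parameters of the model's `forward` signature, plus label columns such as `label`/`label_ids`. Because `RewardModel.forward` takes only `**args`, none of `input_ids`, `attention_mask` or `labels` matched. Every column was therefore dropped, leaving a dataset of size 0. Passing `remove_unused_columns=False` to `TrainingArguments` avoids this.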

This seems to be a bug and not the default behavior, so I reported it to huggingface/transformers.