How to use the Hugging Face (HF) Trainer with a custom collate function?

I have a custom dataset with custom table entries and wanted to handle it with a custom collate function. But it didn’t work when I passed the collate function I wrote to the Trainer, even though that function DOES work with an individual DataLoader (e.g., see python - How does one create a pytorch data loader with a custom hugging face data set without having errors? - Stack Overflow or python - How does one create a pytoch data loader using an interleaved hugging face dataset? - Stack Overflow). It just doesn’t work with the HF Trainer.


from pathlib import Path
# token = open(Path('~/data/hf_token.txt').expanduser()).read().strip()
token = None
batch_size = 8

# -- AF now
from datasets import load_dataset
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
model = GPT2LMHeadModel.from_pretrained("gpt2")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# -- Get batch from dataset
from datasets import load_dataset
# path, name = 'brando/debug1_af', 'debug1_af'
path, name = 'brando/debug0_af', 'debug0_af'
# train_dataset = load_dataset(path, name, streaming=True, split="train", token=token).with_format(type="torch")
# eval_dataset = load_dataset(path, name, streaming=True, split="test", token=token).with_format(type="torch")
# batch = dataset.take(1)
# column_names = next(iterbatch).keys()
# print(f'{column_names=}')

# -- Compute max steps (for real experiments we should probably do this so that every training run sees the same number of tokens, for a fair comparison; todo: ask Sudharsan or online; for now just use streaming=False)
train_dataset = load_dataset(path, name, streaming=False, split="train", token=token).with_format(type="torch")  # hack to get dataset size
eval_dataset = load_dataset(path, name, streaming=False, split="test", token=token).with_format(type="torch") # hack to get dataset size
per_device_train_batch_size = batch_size
num_epochs = 1
max_steps = (len(train_dataset) // per_device_train_batch_size) * num_epochs

# -- Get trainer
def collate_tokenize(data):
    # data is a list of raw rows: build one prompt string per row, then tokenize the whole batch at once
    text_batch = [f'informal statement {example["generated informal statement"]} formal statement {example["formal statement"]}' for example in data]
    tokenized = tokenizer(text_batch, padding='longest', max_length=128, truncation=True, return_tensors='pt')
    return tokenized

from transformers import Trainer, TrainingArguments
training_args = TrainingArguments(
    output_dir=Path('./results').expanduser(),          # output directory
    max_steps=max_steps,             # total number of training steps
    per_device_train_batch_size=per_device_train_batch_size,   # batch size per device during training
    per_device_eval_batch_size=batch_size,    # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir=Path('./logs').expanduser(),            # directory for storing logs
)
trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=eval_dataset,           # evaluation dataset
    data_collator=collate_tokenize,
)
trainer.train()
print('Done!\a')


/usr/local/lib/python3.10/dist-packages/transformers/ FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
IndexError                                Traceback (most recent call last)
<ipython-input-2-4403554fc52d> in <cell line: 63>()
     61     data_collator = collate_tokenize,
     62 )
---> 63 trainer.train()
     64 print('Done!\a')

11 frames
/usr/local/lib/python3.10/dist-packages/datasets/formatting/ in _check_valid_index_key(key, size)
    524     if isinstance(key, int):
    525         if (key < 0 and key + size < 0) or (key >= size):
--> 526             raise IndexError(f"Invalid key: {key} is out of bounds for size {size}")
    527         return
    528     elif isinstance(key, slice):

IndexError: Invalid key: 12 is out of bounds for size 0

Why does this happen, and how can I fix it?

You can avoid this error by passing remove_unused_columns=False to TrainingArguments, but a cleaner solution is to use map to tokenize the dataset before passing it to the Trainer (instead of tokenizing lazily).
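To see why the original code fails: by default (`remove_unused_columns=True`), the Trainer drops every dataset column whose name is not a parameter of the model's `forward` method (label columns are kept too) before building its DataLoader. The raw text columns that `collate_tokenize` reads match none of them, so the collator ends up indexing into a dataset with no usable columns, which appears to be where the `size 0` in the IndexError comes from. The sketch below is a simplified, hypothetical re-creation of that filtering with `inspect.signature`; `forward` and `kept_columns` are illustrative stand-ins, not the Trainer's actual code.

```python
import inspect


def forward(input_ids=None, attention_mask=None, labels=None, past_key_values=None):
    """Hypothetical stand-in for GPT2LMHeadModel.forward (the real one has more parameters)."""


def kept_columns(column_names, forward_fn):
    """Simplified, hypothetical version of the Trainer's unused-column filtering."""
    # A column survives only if it names a forward() parameter or a label column.
    allowed = set(inspect.signature(forward_fn).parameters) | {"label", "label_ids"}
    return [name for name in column_names if name in allowed]


# The two raw columns that collate_tokenize reads.
raw_columns = ["generated informal statement", "formal statement"]
print(kept_columns(raw_columns, forward))  # []

# After tokenizing with map, the tokenizer outputs match forward() parameters.
tokenized_columns = raw_columns + ["input_ids", "attention_mask"]
print(kept_columns(tokenized_columns, forward))  # ['input_ids', 'attention_mask']
```

This is also why tokenizing with `map` first avoids the problem: the new `input_ids` and `attention_mask` columns survive the filtering, while `remove_unused_columns=False` simply skips it.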

After this change, you should get the “The model did not return a loss from the inputs …” error, which you can fix by returning the labels column in the collate function (equal to input_ids).
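A minimal sketch of such a collate function, assuming the rows were already tokenized with `map` (so each one carries an `input_ids` list): it pads to the longest row in the batch, sets `labels` equal to `input_ids`, and replaces padded positions with -100, the index PyTorch's cross-entropy loss ignores by default. GPT-2 shifts the labels internally, so no manual shifting is needed. `collate_with_labels` and its `pad_token_id` parameter are hypothetical names; in practice `pad_token_id` would be `tokenizer.pad_token_id`.

```python
import torch


def collate_with_labels(examples, pad_token_id):
    """Hypothetical collate: pad pre-tokenized rows and add causal-LM labels.

    `examples` is a list of dicts with an "input_ids" list, as produced by
    tokenizing in Dataset.map.
    """
    max_len = max(len(ex["input_ids"]) for ex in examples)
    input_ids = torch.full((len(examples), max_len), pad_token_id, dtype=torch.long)
    attention_mask = torch.zeros((len(examples), max_len), dtype=torch.long)
    for i, ex in enumerate(examples):
        ids = torch.as_tensor(ex["input_ids"], dtype=torch.long)
        input_ids[i, : len(ids)] = ids  # right-padding
        attention_mask[i, : len(ids)] = 1
    # Labels equal input_ids; padded positions are set to -100 so the loss skips them.
    labels = input_ids.clone()
    labels[attention_mask == 0] = -100
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


batch = collate_with_labels([{"input_ids": [5, 6, 7]}, {"input_ids": [8]}], pad_token_id=0)
print(batch["labels"].tolist())  # [[5, 6, 7], [8, -100, -100]]
```

With the Trainer it could be passed as `data_collator=functools.partial(collate_with_labels, pad_token_id=tokenizer.pad_token_id)`.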

(DataCollatorForLanguageModeling, with mlm=False for causal language modeling, handles this automatically, so it’s better to perform the tokenization in map and then use this collator as the data_collator, as explained in our NLP course.)
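A sketch of that recommended combination, kept offline by building a toy word-level tokenizer in memory instead of downloading GPT-2 (the vocabulary and sentences are made up for illustration). The tokenization step produces unpadded `input_ids`; `DataCollatorForLanguageModeling(..., mlm=False)` then pads each batch and creates `labels`, with padded positions set to -100.

```python
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from transformers import DataCollatorForLanguageModeling, PreTrainedTokenizerFast

# Toy tokenizer built in memory so the example runs offline; in the real
# script this would be the GPT-2 tokenizer with pad_token = eos_token.
vocab = {"[PAD]": 0, "[UNK]": 1, "informal": 2, "formal": 3, "statement": 4}
backend = Tokenizer(WordLevel(vocab=vocab, unk_token="[UNK]"))
backend.pre_tokenizer = Whitespace()
tokenizer = PreTrainedTokenizerFast(
    tokenizer_object=backend,
    pad_token="[PAD]",
    unk_token="[UNK]",
    model_input_names=["input_ids", "attention_mask"],
)

# Step 1 (normally done once with Dataset.map): tokenize without padding.
features = [tokenizer("informal statement formal statement"), tokenizer("formal statement")]

# Step 2: the collator pads each batch and builds labels from input_ids.
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
batch = collator(features)
print(batch["input_ids"].tolist())  # [[2, 4, 3, 4], [3, 4, 0, 0]]
print(batch["labels"].tolist())     # [[2, 4, 3, 4], [3, 4, -100, -100]]
```

In the original script this would presumably become `data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)` together with a `train_dataset.map(..., batched=True)` tokenization step. One caveat worth checking: the collator masks labels wherever `input_ids` equals `pad_token_id`, so with GPT-2's `pad_token = eos_token`, genuine EOS tokens are excluded from the loss as well.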

@mariosasko is this the code you had in mind?

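import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments
from datasets import load_dataset

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token; needed for padding below
model = GPT2LMHeadModel.from_pretrained("gpt2")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True)

# path, name as defined in the first snippet
train_dataset = load_dataset(path, name, split="train")
train_dataset = train_dataset.map(tokenize_function, batched=True)

eval_dataset = load_dataset(path, name, split="test")
eval_dataset = eval_dataset.map(tokenize_function, batched=True)

def collate_function(examples):
    # Rows have different lengths (and are plain lists), so torch.stack would fail: pad them instead
    batch = tokenizer.pad(
        [{"input_ids": ex["input_ids"], "attention_mask": ex["attention_mask"]} for ex in examples],
        return_tensors="pt",
    )
    # Causal LM: labels = input_ids, with padded positions ignored by the loss (-100)
    batch["labels"] = batch["input_ids"].clone()
    batch["labels"][batch["attention_mask"] == 0] = -100
    return batch

training_args = TrainingArguments(...)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collate_function,
)
trainer.train()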

Also, I’m realizing that this may be a misunderstanding on my part… but eval_dataset (or any HF dataset) is usually an iterable, so the line:

eval_dataset = eval_dataset.map(tokenize_function, batched=True)

is always applied as needed (lazily), i.e., the code NEVER maps and tokenizes until we fetch a batch (e.g., in the collate function, or when something calls next(iter(mapped_dataset))). Is this right?

If that is right, I don’t think I even need the collate function anymore, because I can simply write the preprocess function and pass it to dataset.map(preprocess, batched=True), and it won’t be called until we actually fetch a new batch (through a Python generator).


tl;dr: I was worried that map(..., batched=True) eagerly preprocesses the entire dataset.

IterableDataset.map (an IterableDataset is returned in streaming mode) is applied lazily, but Dataset.map is not. You can use Dataset.with_transform to lazily apply a transform when accessing a dataset’s rows.

Got it. My use of streaming=True is probably why I had only ever seen the lazy IterableDataset.map and not Dataset.map.

I’m slightly puzzled/curious: why is there a collate_fn option at all if preprocessing + map seems to do everything we need (and the latter is even lazy already)?

Can the input to the collate function be called batch?
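On the last two questions, a sketch that may help: `map` works on fixed chunks of the dataset that have nothing to do with how rows are later grouped into training batches, whereas a collate function receives exactly the list of samples chosen for one training step. That makes it the natural place for batch-dependent work such as padding to the longest sequence in that batch. The DataLoader passes that list positionally, and the Trainer hands its `data_collator` to its DataLoader as `collate_fn`, so the parameter can be named `batch` or anything else. The examples and the `pad_collate` function below are hypothetical.

```python
import torch
from torch.utils.data import DataLoader

# Pre-tokenized, variable-length examples (what Dataset.map would produce).
examples = [
    {"input_ids": [5, 6, 7]},
    {"input_ids": [8]},
    {"input_ids": [9, 10]},
    {"input_ids": [11]},
]


def pad_collate(batch, pad_token_id=0):
    """`batch` is the list of samples the DataLoader picked for this step."""
    max_len = max(len(ex["input_ids"]) for ex in batch)  # longest in *this* batch
    ids = [ex["input_ids"] + [pad_token_id] * (max_len - len(ex["input_ids"])) for ex in batch]
    return {"input_ids": torch.tensor(ids)}


loader = DataLoader(examples, batch_size=2, shuffle=False, collate_fn=pad_collate)
shapes = [tuple(b["input_ids"].shape) for b in loader]
first_batch = next(iter(loader))
print(shapes)  # [(2, 3), (2, 2)]
```

Each batch gets its own width, something a per-example `map` step cannot know in advance.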

Also curious: what are the trade-offs of doing tokenizer(sequences["text"], padding="max_length", max_length=128, truncation=True, return_tensors="pt") instead?
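A rough way to see the main trade-off is to count padding tokens. With `padding="max_length"` every row is padded to 128 tokens. That gives fixed tensor shapes, which means predictable memory use, and static shapes are also friendlier to some accelerators such as TPUs, but it spends compute on padding. Padding per batch, as a padding collator does, only pads up to the longest row in each batch. The lengths below are hypothetical. Separately, `return_tensors="pt"` buys little inside `map`, because the outputs are written back to Arrow storage either way.

```python
# Hypothetical token lengths of four examples after truncation to 128.
lengths = [5, 7, 120, 9]
max_length = 128
batch_size = 2

# padding="max_length": every row is padded to max_length tokens.
fixed_pad = sum(max_length - n for n in lengths)

# Dynamic padding: each batch is padded only to its own longest row.
dynamic_pad = 0
for start in range(0, len(lengths), batch_size):
    chunk = lengths[start:start + batch_size]
    dynamic_pad += sum(max(chunk) - n for n in chunk)

print(fixed_pad, dynamic_pad)  # 371 113
```

The gap grows when most sequences are much shorter than `max_length`, and it shrinks when similar lengths end up in the same batch.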

And how does the model’s context length interact with this? Will the HF model truncate the input later if it’s too long?
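As far as I know, the model itself does not truncate. GPT-2 has learned position embeddings for only `n_positions` tokens (1024 for `gpt2`), and longer inputs raise an error rather than being cut. Truncation therefore has to happen at tokenization time. With `truncation=True` and `max_length=128`, as in the snippets above, every sequence stays well within the limit; with no `max_length`, `truncation=True` falls back to `tokenizer.model_max_length`. The sketch below checks this offline with a tiny, randomly initialised GPT-2 whose context length is set to 16; the configuration values are arbitrary.

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Tiny, randomly initialised GPT-2 so the example runs offline. The values are
# arbitrary; n_positions is the context length (1024 for the real "gpt2").
config = GPT2Config(vocab_size=100, n_positions=16, n_embd=32, n_layer=1, n_head=2)
model = GPT2LMHeadModel(config).eval()


def logits_shape(seq_len):
    """Run one random sequence of `seq_len` token ids through the model."""
    input_ids = torch.randint(0, config.vocab_size, (1, seq_len))
    with torch.no_grad():
        return tuple(model(input_ids=input_ids).logits.shape)


print(logits_shape(16))  # (1, 16, 100): fits the context window

try:
    logits_shape(17)  # one token past n_positions
    too_long_raised = False
except (IndexError, RuntimeError) as err:
    too_long_raised = True
    print("too long:", type(err).__name__)
```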