How do I use the Hugging Face (HF) Trainer to train with a custom collate function?

@mariosasko is this the code you had in mind?

import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments
from datasets import load_dataset

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token, so reuse EOS for padding
model = GPT2LMHeadModel.from_pretrained("gpt2")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def tokenize_function(examples):
    # no padding here; the custom collate function below pads each batch dynamically
    return tokenizer(examples["text"], truncation=True)

train_dataset = load_dataset(path, name, split="train")  # path/name are placeholders for the dataset
train_dataset = train_dataset.map(tokenize_function, batched=True)

eval_dataset = load_dataset(path, name, split="test") 
eval_dataset = eval_dataset.map(tokenize_function, batched=True)

def collate_function(examples):
    # the lists produced by map() have different lengths, so pad each batch to its longest sequence
    batch = tokenizer.pad([{"input_ids": ex["input_ids"]} for ex in examples], return_tensors="pt")
    # causal LM: the labels are the input_ids themselves, with padding ignored in the loss
    batch["labels"] = batch["input_ids"].masked_fill(batch["attention_mask"] == 0, -100)
    return batch

training_args = TrainingArguments(...)  # fill in output_dir, batch size, epochs, etc.

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collate_function
)

trainer.train()
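
(For reference, I think transformers already ships a built-in collator that does roughly the same padding + labels step, so if the hand-written collate_function above is overkill I could presumably just use something like this instead:)

from transformers import DataCollatorForLanguageModeling

# mlm=False -> causal LM collation: pad the batch and set labels from input_ids, with padding masked out of the loss
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

and pass that as data_collator to the Trainer in place of collate_function.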

Also, I’m realizing that perhaps this is a misunderstanding on my part, but eval_dataset (or any HF dataset) is usually an iterable. So the line:

eval_dataset = eval_dataset.map(tokenize_function, batched=True)

is applied as needed (lazily), i.e., the code never actually maps and tokenizes anything until we fetch a batch (e.g., in the collate function, or when something calls next(iter(mapped_dataset))). Is this right?
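
(One way I figure I could sanity-check this: put a print inside the mapped function and see whether it fires at map() time or only once batches are actually pulled. Rough sketch:)

def tokenize_function(examples):
    print("tokenize_function called")  # if map is lazy, this should only print when batches are fetched
    return tokenizer(examples["text"], truncation=True)

eval_dataset = eval_dataset.map(tokenize_function, batched=True)  # does the print fire here already...
batch = next(iter(eval_dataset))                                  # ...or only here?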

If that is right, I don’t think I even need the collate function anymore, because I can simply write the preprocess function, call dataset.map(preprocess, batched=True), and that won’t actually run until we pull a new batch (through a Python generator).
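
Concretely, what I’m imagining is something like this (rough sketch; I’m padding to a fixed length inside preprocess on the assumption that the default collator can then just stack the columns into tensors):

def preprocess(examples):
    # fixed-length padding so no custom collate_fn is needed to batch these examples
    tokenized = tokenizer(examples["text"], truncation=True, padding="max_length")
    tokenized["labels"] = tokenized["input_ids"].copy()  # causal LM: labels are the inputs themselves
    return tokenized

train_dataset = load_dataset(path, name, split="train").map(preprocess, batched=True)

trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)  # no data_collator passed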

Right?