@mariosasko is this the code you had in mind?
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, Trainer, TrainingArguments
from datasets import load_dataset
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
model = GPT2LMHeadModel.from_pretrained("gpt2")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
def tokenize_function(examples):
    # Pad to a fixed length so the examples can later be batched into tensors
    return tokenizer(examples["text"], truncation=True, padding="max_length")
train_dataset = load_dataset(path, name, split="train")
train_dataset = train_dataset.map(tokenize_function, batched=True)
eval_dataset = load_dataset(path, name, split="test")
eval_dataset = eval_dataset.map(tokenize_function, batched=True)
def collate_function(examples):
    # For causal LM fine-tuning, the labels are just a copy of the input ids
    # (ideally, padded positions would be set to -100 so the loss ignores them)
    input_ids = torch.tensor([example["input_ids"] for example in examples])
    labels = input_ids.clone()
    return {"input_ids": input_ids, "labels": labels}
training_args = TrainingArguments(...)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collate_function
)
trainer.train()
Also, I'm realizing this may be a misunderstanding on my part, but eval_dataset
(or any HF dataset) is usually an iterable. So the line:
eval_dataset = eval_dataset.map(tokenize_function, batched=True)
is always applied as needed (lazily), i.e., the code never maps and tokenizes until we actually fetch a batch (e.g., in the collator, or when something calls next(iter(mapped_dataset))). Is this right?
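For reference, here is a minimal probe I could run to see when the mapping actually happens (the "imdb" dataset and the probe function are just placeholders for illustration):

from datasets import load_dataset

def probe(examples):
    # If .map() is eager, this prints while .map() itself runs;
    # if it is lazy, it only prints once rows are actually read.
    print(f"processing a batch of {len(examples['text'])} examples")
    return examples

ds = load_dataset("imdb", split="test")                         # map-style Dataset
ds_stream = load_dataset("imdb", split="test", streaming=True)  # streaming IterableDataset

mapped = ds.map(probe, batched=True)                # watch whether the print fires here...
mapped_stream = ds_stream.map(probe, batched=True)
next(iter(mapped_stream))                           # ...or only once the first batch is pulled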
If that is right, I don't think I even need the collate function anymore: I can simply write a preprocess function, pass it via dataset.map(preprocess, batched=True),
and it won't actually be called until we pull a new batch (through a Python generator).
Right?
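In case it helps to make the question concrete, this is roughly what I have in mind, reusing the tokenizer, model, training_args, and path/name placeholders from the snippet above (the padding and label handling are just my assumption of what preprocess would need so the default collator works):

def preprocess(examples):
    # Tokenize to a fixed length and, for causal LM, use a copy of input_ids as labels
    tokenized = tokenizer(examples["text"], truncation=True, padding="max_length")
    tokenized["labels"] = tokenized["input_ids"].copy()
    return tokenized

train_dataset = load_dataset(path, name, split="train").map(preprocess, batched=True)
eval_dataset = load_dataset(path, name, split="test").map(preprocess, batched=True)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # no data_collator: the default collator should just turn the fixed-length columns into tensors
)
trainer.train()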