I am trying to implement a custom data collator that tokenizes batches on the fly. But somehow, the features passed to it do not have any field/key other than `label`. Hence, in my code below, when trying to extract the text for tokenization, I get `KeyError: 'text1'`. I read the source code of the collators in `transformers`, but I don't see where the non-`label` fields are dropped. Can someone help me?
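The dataset itself definitely contains both columns; here is a quick standalone check with the same hardcoded data I use below:

```python
import datasets

train_ds = datasets.Dataset.from_dict({
    "text1": ["A", "B", "C", "D", "E"],
    "label": [0, 0, 0, 1, 1],
})
print(train_ds.column_names)  # ['text1', 'label']
print(train_ds[0])            # {'text1': 'A', 'label': 0}
```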
Below is my code, which uses that hardcoded dataset with two keys/fields, namely `text1` and `label`:
```python
from typing import Any, Dict, List

import torch
import datasets
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

train_ds = datasets.Dataset.from_dict({
    "text1": ["A", "B", "C", "D", "E"],
    "label": [0, 0, 0, 1, 1],
})


class SmartCollator(DataCollatorWithPadding):
    """Tokenize each batch on the fly."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def __call__(self, features: List[Dict[str, Any]]):
        texts = [f["text1"] for f in features]  # <- KeyError: 'text1' (features only contain 'label')
        labels = [f["label"] for f in features]
        encodings = self.tokenizer(
            texts, truncation=True, padding="longest", max_length=20, return_tensors="pt"
        )
        # Assign rather than `return encodings.update(...)`: update() returns None.
        encodings["labels"] = torch.tensor(labels)
        return encodings


training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=1,
    per_device_train_batch_size=2,
    report_to="none",
)

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=SmartCollator(tokenizer=tokenizer),
    train_dataset=train_ds,
)
trainer.train()  # fails with KeyError: 'text1' inside SmartCollator.__call__
```
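One more data point: if I bypass the Trainer and call the collator directly on raw dataset rows (reusing the definitions from the script above, before `trainer.train()`), there is no KeyError, so `text1` only goes missing when the batch comes through the Trainer:

```python
# Direct call using SmartCollator, tokenizer, and train_ds defined above:
rows = [train_ds[0], train_ds[1]]                 # each row still has 'text1' and 'label'
batch = SmartCollator(tokenizer=tokenizer)(rows)  # no KeyError here
print(sorted(batch.keys()))                       # ['attention_mask', 'input_ids', 'labels']
```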