I tried everything, but I keep getting a tokenizer error. Here is my simple code:
from datasets import Dataset, load_dataset
from transformers import AutoTokenizer, BertTokenizer, DataCollatorForLanguageModeling, Trainer, TrainingArguments, AutoModelForCausalLM

context_dataset = Dataset.from_dict({
    "text": ["context sentence 1", "context sentence 2", "context sentence 3"]
})
print(context_dataset)

# Load a pre-trained tokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-distilled-squad")
print(tokenizer)

# Tokenize your dataset
def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

tokenized_datasets = context_dataset.map(tokenize_function, batched=True, num_proc=3, remove_columns=["text"])
The reason for that error is that the tokenizer was not declared inside the tokenize_function function. A simple fix is to create the tokenizer instance inside tokenize_function, but this is inefficient because it loads a new tokenizer every time the function is called. Another simple fix would be to add a tokenizer argument to tokenize_function, but the function you pass to Datasets' map method only receives one argument (examples) by default.
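To make the inefficiency concrete, this is what the inline approach would look like (a sketch, using the same checkpoint as above):

# Avoid this: a fresh tokenizer is loaded from the checkpoint on every call
def tokenize_function(examples):
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-distilled-squad")
    return tokenizer(examples["text"], padding="max_length", truncation=True)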
There are two cleaner solutions to this problem, described below.
Solution 1. Use a TokenizerWrapper class
class TokenizerWrapper:
    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def tokenize_function(self, examples):
        return self.tokenizer(
            examples["text"],
            padding="max_length",
            truncation=True,
        )

# The tokenizer is stored on the instance, so the bound method
# tokenizer_wrapper.tokenize_function carries it along when map pickles
# the function for its worker processes.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-distilled-squad")
tokenizer_wrapper = TokenizerWrapper(tokenizer)
tokenized_dataset = context_dataset.map(tokenizer_wrapper.tokenize_function, batched=True, num_proc=3, remove_columns=["text"])
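A quick sanity check: after mapping, the raw text column is gone and the dataset holds the tokenizer's outputs (for a DistilBERT tokenizer these are input_ids and attention_mask):

print(tokenized_dataset.column_names)  # ['input_ids', 'attention_mask']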
Solution 2. Use a partial function
from functools import partial
from transformers import AutoTokenizer

def tokenize_function(tokenizer, examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased-distilled-squad")
# Bind the tokenizer as the first argument; map then only has to supply examples.
partial_tokenize_function = partial(tokenize_function, tokenizer)
tokenized_dataset = context_dataset.map(partial_tokenize_function, batched=True, num_proc=3, remove_columns=["text"])
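For completeness, map also accepts a fn_kwargs dict that it forwards to your function as keyword arguments, which avoids both the wrapper and partial. A minimal sketch, assuming tokenize_function is rewritten to take the batch first and the tokenizer as a keyword argument:

def tokenize_function(examples, tokenizer):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

tokenized_dataset = context_dataset.map(
    tokenize_function,
    batched=True,
    num_proc=3,
    remove_columns=["text"],
    fn_kwargs={"tokenizer": tokenizer},  # forwarded to tokenize_function as keyword arguments
)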