laro1
August 8, 2022, 12:45pm
1
I want to call the DatasetDict map function with parameters, and I don't know how to do it.
I have a function with the following signature:
def tokenize_function(tokenizer, examples):
    s1 = examples["premise"]
    s2 = examples["hypothesis"]
    args = (s1, s2)
    return tokenizer(*args, padding="max_length", truncation=True)
When I try to use it this way:
dataset = load_dataset("json", data_files=data_files)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenized_datasets = dataset.map(tokenize_function, tokenizer, batched=True)
I get this error:
TypeError: list indices must be integers or slices, not str
How can I call the map function in my example?
nielsr
August 8, 2022, 5:03pm
2
The function shouldn’t take the tokenizer as input, only the examples. You can add additional arguments to the function, which can then be passed as fn_kwargs in the map method.
Option 1: define the tokenizer outside the function
from datasets import load_dataset
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
dataset = load_dataset("json", data_files=data_files)
def tokenize_function(examples):
    s1 = examples["premise"]
    s2 = examples["hypothesis"]
    args = (s1, s2)
    # tokenizer is picked up from the enclosing module scope
    return tokenizer(*args, padding="max_length", truncation=True)
# only the function is passed; the tokenizer is not an argument to map
tokenized_datasets = dataset.map(tokenize_function, batched=True)
Option 2: use fn_kwargs to pass the tokenizer to the function
from datasets import load_dataset
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
dataset = load_dataset("json", data_files=data_files)
def tokenize_function(examples, tokenizer):
    s1 = examples["premise"]
    s2 = examples["hypothesis"]
    args = (s1, s2)
    return tokenizer(*args, padding="max_length", truncation=True)
# fn_kwargs is forwarded to tokenize_function as keyword arguments
tokenized_datasets = dataset.map(tokenize_function, fn_kwargs={"tokenizer": tokenizer}, batched=True)
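As for where the TypeError in the question comes from: as far as I can tell, the second positional parameter of map is with_indices, so a tokenizer passed positionally is treated as a truthy with_indices flag. The function is then called with the batch and a list of row indices, and since the question's function takes the tokenizer first, examples ends up being that list. The sketch below imitates this with a toy stand-in and no real model. toy_map and fake_tokenizer are hypothetical names made up for illustration, not part of datasets or transformers, and the real map does far more than this.

```python
# Toy stand-in for a map-style API, only to illustrate argument handling.
# `toy_map` and `fake_tokenizer` are hypothetical, not library functions.

def toy_map(batch, function, with_indices=False, fn_kwargs=None):
    """Call `function` on a single batch (a dict of column -> list)."""
    fn_kwargs = fn_kwargs or {}
    if with_indices:
        # Row indices are passed as a second positional argument.
        indices = list(range(len(next(iter(batch.values())))))
        return function(batch, indices, **fn_kwargs)
    return function(batch, **fn_kwargs)


def fake_tokenizer(first, second):
    """Pretend tokenizer: returns word counts instead of token ids."""
    return {"length": [len(a.split()) + len(b.split())
                       for a, b in zip(first, second)]}


def bad_tokenize(tokenizer, examples):
    # Parameter order from the question: tokenizer first, examples second.
    return tokenizer(examples["premise"], examples["hypothesis"])


def good_tokenize(examples, tokenizer):
    # Batch first; the extra argument arrives by keyword via fn_kwargs.
    return tokenizer(examples["premise"], examples["hypothesis"])


batch = {"premise": ["a b c", "d e"], "hypothesis": ["f", "g h"]}

# fake_tokenizer lands in the with_indices slot and is truthy, so
# bad_tokenize receives (batch, indices) and indexes a list with a string.
try:
    toy_map(batch, bad_tokenize, fake_tokenizer)
    error_message = None
except TypeError as err:
    error_message = str(err)

result = toy_map(batch, good_tokenize, fn_kwargs={"tokenizer": fake_tokenizer})

print(error_message)
print(result)
```

Under these assumptions the first call fails the same way as in the question, while the fn_kwargs call works because the batch stays in the first position and the tokenizer is bound by name.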
How can I apply a function to a Dataset in batches? Here is my code:
def get_sent(examples):
    sents = sentiment_classifier(examples['text'])
    labels = [sent[0]['label'] for sent in sents]
    examples['label'] = labels
    return examples
to_label1 = to_label1.map(get_sent, batched=True)
When I run the code I get an error, and I can't run the function in batched mode.
Can you paste the error message?