Dataset map method - how to pass arguments to the function

Hi, I just started using the Hugging Face library. I am wondering how I can pass a model and tokenizer to my processing function, along with the batch, when using the map method.

def my_processing_func(batch, model, tokenizer):
    ...  # processing code

I am using map like this…
new_dataset = my_dataset.map(my_processing_func, model, tokenizer, batched=True)

When I do this, it does not fail, but instead of passing the dictionary with input_ids and attention_mask, it passes a list of just input_ids as the batch to my_processing_func. When I remove the model and tokenizer arguments, it sends the dictionary as expected.

Where am I going wrong?

Thanks in advance.

Hi! You can use fn_kwargs to have map pass the extra arguments on to your function:

new_dataset = my_dataset.map(my_processing_func, batched=True, fn_kwargs={"model": model, "tokenizer": tokenizer})
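To make the fn_kwargs mechanism concrete, here is a minimal, stdlib-only sketch. `toy_batched_map`, `fake_tokenizer` and the `text` column are hypothetical; the assumption (a simplification, not the library's actual code) is that batched map slices every column into a dict and calls `function(batch, **fn_kwargs)`. Since map does not accept arbitrary extra positional arguments, anything passed positionally after the function binds to map's own parameters instead of reaching your function, which is why the original call misbehaved without raising.

```python
from typing import Any, Callable, Dict, List, Optional

Batch = Dict[str, List[Any]]


def toy_batched_map(columns: Batch, function: Callable[..., Batch],
                    batch_size: int = 2,
                    fn_kwargs: Optional[Dict[str, Any]] = None) -> Batch:
    # Hypothetical, simplified stand-in for Dataset.map(batched=True).
    fn_kwargs = fn_kwargs or {}
    n_rows = len(next(iter(columns.values())))
    out: Batch = {}
    for start in range(0, n_rows, batch_size):
        # All columns are sliced together, so the function receives a dict.
        batch = {name: vals[start:start + batch_size] for name, vals in columns.items()}
        # Extra arguments reach the function only as keywords from fn_kwargs.
        result = function(batch, **fn_kwargs)
        for name, vals in result.items():
            out.setdefault(name, []).extend(vals)
    return out


def fake_tokenizer(texts: List[str]) -> Batch:
    # Hypothetical tokenizer: one "id" per word, equal to the word's length.
    ids = [[len(w) for w in t.split()] for t in texts]
    return {"input_ids": ids, "attention_mask": [[1] * len(x) for x in ids]}


def my_processing_func(batch: Batch, model: Any, tokenizer: Callable) -> Batch:
    encoded = tokenizer(batch["text"])
    # A real function would run `model` here; this sketch only records its name.
    encoded["model"] = [model] * len(batch["text"])
    return encoded


data = {"text": ["a bb", "ccc", "dd e f"]}
out = toy_batched_map(data, my_processing_func,
                      fn_kwargs={"model": "my-model", "tokenizer": fake_tokenizer})
print(out["input_ids"])  # [[1, 2], [3], [2, 1, 1]]
```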

Or you can use partial:

from functools import partial
new_dataset = my_dataset.map(partial(my_processing_func, model=model, tokenizer=tokenizer), batched=True)
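A small stdlib-only illustration of what partial does here. The body of `my_processing_func` and the placeholder strings standing in for a real model and tokenizer are hypothetical. partial pre-binds the keyword arguments and returns a callable that only needs the batch, so the bound values arrive the same way fn_kwargs would deliver them.

```python
from functools import partial


def my_processing_func(batch, model, tokenizer):
    # Hypothetical stand-in: prefix every text with the model/tokenizer names.
    return {"text": [f"{model}/{tokenizer}: {t}" for t in batch["text"]]}


# partial binds model and tokenizer as keyword arguments, leaving a callable
# that takes only the batch, which is how map calls the function.
fn = partial(my_processing_func, model="my-model", tokenizer="my-tokenizer")

batch = {"text": ["hello", "world"]}
print(fn(batch))
# {'text': ['my-model/my-tokenizer: hello', 'my-model/my-tokenizer: world']}
print(fn.keywords)  # {'model': 'my-model', 'tokenizer': 'my-tokenizer'}
```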

Is there any downside to using either option? If I remember correctly, lambdas are not picklable, so my assumption would be that if you do something like

new_dataset = my_dataset.map(lambda batch: my_processing_func(batch, model, tokenizer), batched=True)

it won’t be cached. Is that correct?

Super!! This works for me… thanks 🙏

  1. There shouldn’t be a significant difference in speed between the fn_kwargs and partial approaches.
  2. We use dill, which knows how to pickle lambdas in most situations, so the lambda version can usually still be cached.
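For context on why dill matters, this stdlib-only sketch shows that the built-in pickle module rejects a lambda: it stores functions by qualified name, and `<lambda>` cannot be looked up again. The helper `stdlib_can_pickle` and the placeholder model/tokenizer strings are hypothetical, and the sketch does not exercise the library's own caching or hashing code.

```python
import pickle


def my_processing_func(batch, model, tokenizer):
    # Hypothetical stand-in for the real processing code.
    return batch


model, tokenizer = "my-model", "my-tokenizer"  # placeholder objects


def stdlib_can_pickle(obj) -> bool:
    """Return True if the standard-library pickle module can serialize obj."""
    try:
        pickle.dumps(obj)
        return True
    except (pickle.PicklingError, AttributeError, TypeError):
        return False


# Standard pickle references functions by name, which fails for a lambda;
# dill can instead serialize the function's code itself.
as_lambda = lambda batch: my_processing_func(batch, model, tokenizer)
print(stdlib_can_pickle(as_lambda))  # False
```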