Avoiding hashing in `map`

I have a function (let’s call it f) that involves running an LLM on strings from the IMDB movie reviews dataset. However, when I run

imdb.map(f)

it grinds to a halt when I am using a large model. I think the problem is that `map` is trying to hash `f`, which references a large model in its body, and this is very expensive. I found that setting `new_fingerprint='foo', load_from_cache_file=False` fixes this, but my understanding is that this still creates a cache file that will never be used.
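
Roughly, the pattern is the following (with load_model() standing in for however the LLM actually gets loaded):

from datasets import load_dataset

imdb = load_dataset("imdb", split="test")
model = load_model()  # placeholder for loading the large LLM

def f(example):
    # f closes over `model`, so map pickles the whole model when computing the fingerprint
    return {"output": model(example["text"])}

imdb = imdb.map(f)  # grinds to a halt here with a large model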

I’ve tried using `datasets.disable_caching()` before running `map`, but it still appears to be trying to hash `f`.

What is the right way to do this? I would have thought that mapping LLMs over datasets is a common use case, so I was surprised that this (obscure?) hashing problem doesn’t have an obvious workaround (unless I’m missing something).

Thanks

A few ways to work around this:

  1. Use a Simple Wrapper for the Function:
def f_wrapper(example):
    model = load_model()  # placeholder loader; loading here keeps the model out of the hashed closure (point 4 avoids reloading on every call)
    return {'output': model(example['text'])}  # map functions must return a dict of columns
  2. Use batched=True in map():
imdb = imdb.map(f_wrapper, batched=True, load_from_cache_file=False)  # f_wrapper then receives a batch (a dict of lists) rather than one example
  3. Avoid Using datasets.map() Entirely:
results = []
for example in imdb:
    result = f(example)
    results.append(result)
  4. Caching at the Function Level:
model = None

def f(example):
    global model
    if model is None:
        model = load_model()  # Load the model only once
    return {'output': model(example['text'])}  # return a dict of columns, as map expects
  5. Disable the hashing in map() (see the combined sketch after this list):
imdb = imdb.map(f, new_fingerprint='foo', load_from_cache_file=False)
  6. Parallelize with num_proc in map():
imdb = imdb.map(f_wrapper, batched=True, num_proc=4, load_from_cache_file=False)  # each worker process loads its own model inside f_wrapper
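
Putting 2, 4, and 5 together: load the model lazily inside the mapped function (so map never has to pickle the model when fingerprinting) and pass new_fingerprint so the function isn’t hashed at all. A minimal sketch, assuming load_model() is a placeholder for however the LLM gets loaded and "output" is just an example column name:

from datasets import load_dataset

imdb = load_dataset("imdb", split="test")

model = None  # loaded lazily, so it isn't in the closure when map fingerprints the function

def predict(batch):
    global model
    if model is None:
        model = load_model()  # placeholder: load the LLM once per process
    # with batched=True, batch["text"] is a list of strings
    return {"output": model(batch["text"])}

imdb = imdb.map(
    predict,
    batched=True,
    batch_size=32,
    new_fingerprint="imdb-llm-v1",  # any fixed string; skips hashing the function
    load_from_cache_file=False,     # don't try to reuse an existing cache file
)

If you also want num_proc > 1, keep the lazy loading: each worker will then load its own copy of the model instead of the parent process trying to pickle it across.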