Fetching rows of a large Dataset by index

I was referred here by @lhoestq from this github issue.

Background
I have a large dataset, ds_all_utts, of user utterances. I load it using load_from_disk because I saved it with save_to_disk:

ds_all_utts = load_from_disk(ds_all_utts_fname)

ds_all_utts has 2,732,013 rows and these features:

{'ANY': Value(dtype='int64', id=None),
 'COMPLAINTCLARIFICATION': Value(dtype='int64', id=None),
 'COMPLAINTMISHEARD': Value(dtype='int64', id=None),
 'COMPLAINTPRIVACY': Value(dtype='int64', id=None),
 'COMPLAINTREPETITION': Value(dtype='int64', id=None),
 'CRITICISM': Value(dtype='int64', id=None),
 'NEGATIVENAVIGATION': Value(dtype='int64', id=None),
 'OFFENSIVE': Value(dtype='int64', id=None),
 'STOP': Value(dtype='int64', id=None),
 'embedding': Sequence(feature=Value(dtype='float32', id=None), length=-1, id=None),
 'frequency': Value(dtype='int64', id=None),
 'user_utterance': Value(dtype='string', id=None)}

user_utterance is a short piece of text (usually just a few words), embedding is a 1280-length vector representing that utterance, frequency is an int, and the rest are binary labels (0 or 1) for the utterance. The dataset is sorted by descending frequency.

I have another Dataset called neuralgen_ds whose rows represent turns of dialogue along with their context. It has 385,580 rows and these features:

{'session_id': Value(dtype='string', id=None),
 'treelet': Value(dtype='string', id=None),
 'context': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
 'bot_utt': Value(dtype='string', id=None),
 'bot_utt_labels': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
 'user_utt': Value(dtype='string', id=None),
 'user_utt_labels': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None),
 'GPT2ED': Value(dtype='bool', id=None),
 '__index_level_0__': Value(dtype='int64', id=None)}

Of these, the important one is user_utt, which is the same type of data as ds_all_utts['user_utterance']. Some user utterances appear multiple times in neuralgen_ds; there are 190,602 unique utterances in neuralgen_ds['user_utt'].

What I want to do
For each row of neuralgen_ds, I want to look up the user utterance in ds_all_utts and copy over certain columns into neuralgen_ds. In particular, I want to copy over embedding and all the capitalized binary labels (ANY, COMPLAINTCLARIFICATION, etc.).

My code
First I create a dictionary mapping from a user utterance to its position in ds_all_utts:

ds_all_utts_lookup = {utt: idx for idx, utt in enumerate(ds_all_utts['user_utterance'])}

Then I use .map to add the columns to neuralgen_ds:

cols = ['embedding', 'ANY', 'COMPLAINTCLARIFICATION', 'COMPLAINTMISHEARD', 'COMPLAINTPRIVACY', 'COMPLAINTREPETITION', 'CRITICISM', 'NEGATIVENAVIGATION', 'OFFENSIVE', 'STOP']

def map_fn(examples):
    user_utts = examples['user_utt']  # list of str
    idxs = [ds_all_utts_lookup[user_utt] for user_utt in user_utts]  # list of int
    ds_slice = ds_all_utts[idxs]  # dict
    result = {col: ds_slice[col] for col in cols}
    return result

neuralgen_ds = neuralgen_ds.map(map_fn, batched=True, batch_size=100)

The tqdm estimate says this .map will take over 8 hours. Adjusting batch_size doesn’t seem to help. The slowest part of map_fn is this line:

ds_slice = ds_all_utts[idxs]  # dict

Other questions

Are you on an SSD or an HDD?

I’m not sure, but I followed these instructions and got:

>>> lsblk -o name,rota
NAME   ROTA
sda       1
├─sda1    1
├─sda2    1
├─sda5    1
├─sda6    1
├─sda7    1
└─sda8    1
sdb       1
└─sdb1    1

Hi !

To speed up map operations, you can run them with multiprocessing by specifying num_proc= in map. Usually it’s best to set it to the number of cores your CPU has.
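
For example, building on your .map call above (the 8 below is just a placeholder for your actual core count):

# same map as before, but distributed across 8 worker processes
neuralgen_ds = neuralgen_ds.map(map_fn, batched=True, batch_size=100, num_proc=8)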

Let me know if that helps.

Also, it looks like you’re using an HDD, which is slower than an SSD. Since your script does a lot of read/write operations (writing one dataset while reading data from another), I’d expect the HDD slows down the process a bit, unfortunately.


Hi @lhoestq,

I tried raising num_proc but this didn’t seem to help.

Generally, I think the problem can be abstracted away from the .map function, or even from the need to write data to a new dataset. A simplified version of my problem just requires me to fetch a certain ~300k rows from my large ds_all_utts dataset (and just hold the result in memory; I have enough). If I could do that efficiently, then my problem would be solved.
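
In code, the simplified version looks something like this (just a sketch; wanted_idxs and fetched are names I made up for illustration):

# positions in ds_all_utts of the unique utterances appearing in neuralgen_ds
wanted_idxs = [ds_all_utts_lookup[utt] for utt in set(neuralgen_ds['user_utt'])]

# this single fetch is the operation I need to make fast; the result
# (a dict of columns) fits comfortably in memory
fetched = ds_all_utts[wanted_idxs]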

As you noted in this github issue, querying rows of a dataset gets slower as the dataset gets bigger, so that’s my fundamental problem.

Do you have any ideas for how I could more efficiently query my large dataset? You mentioned that you can access 6 million wikipedia articles in less than a minute on your laptop - what format are you saving that dataset, and how are you accessing it?

The optimal way to query examples is to query slices of contiguous data. This is what I use to iterate over Wikipedia quickly. Are the idxs that you want to fetch contiguous?
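
To illustrate the difference (just a sketch, with arbitrary index values):

# contiguous slice: one read of adjacent rows in the Arrow file, very fast
batch = ds_all_utts[100000:101000]

# scattered indices: each row is fetched separately, which is much slower
batch = ds_all_utts[[3, 1500000, 2700000]]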

Are the idxs that you want to fetch contiguous?

No. The large dataset contains pre-computed embeddings (and other labels) for all user utterances. Given any particular user utterance (or in this scenario, a particular set of ~300k utterances), I want to be able to fetch its pre-computed embedding from the big dataset.

I don’t think there’s an easy way to make it faster without diving into Arrow optimization stuff.

Though I’m curious, how much time does it take to fetch random indices from ds_all_utts on average? Usually even for big datasets it stays below 1ms.

Though I’m curious, how much time does it take to fetch random indices from ds_all_utts on average? Usually even for big datasets it stays below 1ms.

I ran this:

import time

for idx in range(0, len(ds_all_utts), 100000):
    t0 = time.time()
    row = ds_all_utts[idx]
    print(f"fetching idx {idx} of {len(ds_all_utts)} took {time.time()-t0} seconds")

…and found that it takes longer for higher indexes:

fetching idx 0 of 2732013 took 0.0025310516357421875 seconds
fetching idx 100000 of 2732013 took 0.018671512603759766 seconds
fetching idx 200000 of 2732013 took 0.027704238891601562 seconds
fetching idx 300000 of 2732013 took 0.03670859336853027 seconds
fetching idx 400000 of 2732013 took 0.04753923416137695 seconds
fetching idx 500000 of 2732013 took 0.06029653549194336 seconds
fetching idx 600000 of 2732013 took 0.07588982582092285 seconds
fetching idx 700000 of 2732013 took 0.08530783653259277 seconds
fetching idx 800000 of 2732013 took 0.09908699989318848 seconds
fetching idx 900000 of 2732013 took 0.10945987701416016 seconds
fetching idx 1000000 of 2732013 took 0.11123108863830566 seconds
fetching idx 1100000 of 2732013 took 0.10822105407714844 seconds
fetching idx 1200000 of 2732013 took 0.11907958984375 seconds
fetching idx 1300000 of 2732013 took 0.13410329818725586 seconds
fetching idx 1400000 of 2732013 took 0.15185165405273438 seconds
fetching idx 1500000 of 2732013 took 0.20782732963562012 seconds
fetching idx 1600000 of 2732013 took 0.18477439880371094 seconds
fetching idx 1700000 of 2732013 took 0.3052663803100586 seconds
fetching idx 1800000 of 2732013 took 0.21129512786865234 seconds
fetching idx 1900000 of 2732013 took 0.22873163223266602 seconds
fetching idx 2000000 of 2732013 took 0.24343037605285645 seconds
fetching idx 2100000 of 2732013 took 0.25824809074401855 seconds
fetching idx 2200000 of 2732013 took 0.34645700454711914 seconds
fetching idx 2300000 of 2732013 took 0.29653358459472656 seconds
fetching idx 2400000 of 2732013 took 0.3089408874511719 seconds
fetching idx 2500000 of 2732013 took 0.32140231132507324 seconds
fetching idx 2600000 of 2732013 took 0.3361659049987793 seconds
fetching idx 2700000 of 2732013 took 0.41907739639282227 seconds

Indeed this is slower than expected.
May I suggest running ds_all_utts.set_format("numpy", columns=["embedding"]) first and re-running your benchmark to see if it speeds things up?
Arrow has a NumPy integration that allows reading data with zero copy, so this should increase the speed.
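
For example, a quick way to check that the format change took effect (just a sanity-check sketch):

ds_all_utts.set_format("numpy", columns=["embedding"])
row = ds_all_utts[0]
# with columns=..., only the formatted columns are returned by default
print(type(row["embedding"]))  # expect numpy.ndarray rather than a Python list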

Doing ds_all_utts.set_format("numpy", columns=["embedding"]) doesn’t help much:

fetching idx 0 of 2732013 took 0.0039942264556884766 seconds
fetching idx 100000 of 2732013 took 0.03906583786010742 seconds
fetching idx 200000 of 2732013 took 0.0478971004486084 seconds
fetching idx 300000 of 2732013 took 0.06415033340454102 seconds
fetching idx 400000 of 2732013 took 0.08065438270568848 seconds
fetching idx 500000 of 2732013 took 0.0943913459777832 seconds
fetching idx 600000 of 2732013 took 0.08204221725463867 seconds
fetching idx 700000 of 2732013 took 0.11144042015075684 seconds
fetching idx 800000 of 2732013 took 0.10394906997680664 seconds
fetching idx 900000 of 2732013 took 0.11541366577148438 seconds
fetching idx 1000000 of 2732013 took 0.1507728099822998 seconds
fetching idx 1100000 of 2732013 took 0.1657857894897461 seconds
fetching idx 1200000 of 2732013 took 0.1772749423980713 seconds
fetching idx 1300000 of 2732013 took 0.17113184928894043 seconds
fetching idx 1400000 of 2732013 took 0.2014932632446289 seconds
fetching idx 1500000 of 2732013 took 0.21463871002197266 seconds
fetching idx 1600000 of 2732013 took 0.2281172275543213 seconds
fetching idx 1700000 of 2732013 took 0.21286630630493164 seconds
fetching idx 1800000 of 2732013 took 0.22928833961486816 seconds
fetching idx 1900000 of 2732013 took 0.25535106658935547 seconds
fetching idx 2000000 of 2732013 took 0.26118993759155273 seconds
fetching idx 2100000 of 2732013 took 0.3325386047363281 seconds
fetching idx 2200000 of 2732013 took 0.3626432418823242 seconds
fetching idx 2300000 of 2732013 took 0.37558627128601074 seconds
fetching idx 2400000 of 2732013 took 0.3952178955078125 seconds
fetching idx 2500000 of 2732013 took 0.3839607238769531 seconds
fetching idx 2600000 of 2732013 took 0.35515308380126953 seconds
fetching idx 2700000 of 2732013 took 0.38485121726989746 seconds

Is this behaviour expected, i.e. slower lookup for rows later in the dataset? It says here that random access is O(1):

So caching the dataset directly on disk can use memory-mapping and pay effectively zero cost with O(1) random access. The default in 🤗 Datasets is thus to always memory-map dataset on drive.

Does the fact that I’m loading my dataset using load_from_disk have any effect on this?

It’s not O(1) for some reason. It’s a current limitation of our usage of Arrow (or Arrow itself maybe). I’m in contact with the Arrow team to find a solution.

I’ll update this thread when I have more info.

Also since using numpy with zero-copy doesn’t speed up the reading, this makes me think that the bottleneck is the read speed of your hard drive, combined with the Arrow behavior that makes it take more time to fetch examples at higher indexes.


I just got confirmation from the Arrow team that it’s not O(1).
(here is the source if you’re interested)

We’ll work on a solution to speed this up, now that we know what causes this latency!
For example, the Arrow implementation in Julia already provides much faster, non-linear access times.