Multiprocessing `map` taking up too much memory


I use `map` to preprocess my very large dataset (about 450 GB). Running `map` on 1 node / 1 GPU to do the tokenization works fine, but I got stuck when moving to multi-node training on 10 nodes / 8 GPUs (num_proc = 18). I checked my system status, and it seems my memory overflows, which didn't happen in the 1 node / 1 GPU case. How can I fix it? Thanks!

Here's my training system status on 1 node / 1 GPU; the overflow happens when moving to 8 GPUs.

Total memory: 950 GB

I'm wondering whether this memory footprint is normal. Why does `map` take so much memory (about 400 GB)? I'm happy to provide more details about my training setup.

Hi! What do you get when you access the `cache_files` attribute of your dataset? `map` with multiprocessing can be an issue for in-memory datasets, due to the data being copied to the subprocesses (more info).

Thank you for the reply! @mariosasko
I'm not sure about `cache_files`, but I guess the dataset is cached to disk? There are messages like "found cached files from…" before `map` runs. I use `map` like this:

        with training_args.main_process_first(desc="train dataset map pre-processing"):
            train_dataset = train_dataset.map(
                tokenize_function,  # my tokenization function
                batched=True,
                num_proc=data_args.preprocessing_num_workers,
                load_from_cache_file=not data_args.overwrite_cache,
                desc="running tokenizer on train dataset",
            )
The main process starts taking up a lot of memory after creating the num_proc workers to execute `map`, even though I have already run `map` before, so there is nothing to do but load the cache files.
I found that in the multi-GPU setup each subprocess executes `map` too, and memory usage jumps to 100%.

I saved the dataset after `map`, loaded it directly (so `map` is no longer needed), and removed `map` from my script. Everything works fine that way, so I think there is some problem with the workers in `map`.

I have two questions:

  1. Do all workers share the data, or does each get a copy?
  2. How can I control the memory use of each worker?
  1. Do all workers share the data, or does each get a copy?

When you load a dataset, it actually memory-maps the data from your disk. All workers memory-map the same files, so this memory is shared.

  2. How can I control the memory use of each worker?

It's the sum of the physical memory used by your job and the physical memory occupied by pages from memory-mapped files. Memory mapping brings pages into physical memory when you access the dataset and evicts them after a while if they're not used. Pages are also evicted automatically if your job needs a lot of memory.