Expected memory usage of Dataset

Hello. I am having some difficulty with the memory usage of my Dataset. I’ve been working with Hugging Face for a while, so I think I’ve read pretty much everything the docs provide about this topic. From Big data? Datasets to the rescue! I am under the impression that even a huge Dataset should not require much RAM, because it is memory-mapped from disk. However, in my experiments, I keep running out of memory. Here is essentially the code I am working with.

import os

import psutil
from datasets import Dataset

def print_mem():
    # Report this process's resident set size (RSS) in gigabytes.
    gig = psutil.Process(os.getpid()).memory_info().rss / 1024 ** 3
    print(f"{gig} gigabytes")

print_mem()
dataset = Dataset.from_generator(
    iterable,  # my generator of examples (placeholder name)
    keep_in_memory=KEEP_IN_MEMORY,
    writer_batch_size=WRITER_BATCH_SIZE,
)
print_mem()
dataset = dataset.map(
    fn,
    batched=True,
    keep_in_memory=KEEP_IN_MEMORY,
    writer_batch_size=WRITER_BATCH_SIZE,
)
print_mem()

As expected, when KEEP_IN_MEMORY = True, I run out of memory and the process is killed. However, when I set KEEP_IN_MEMORY = False, I find that the dataset still takes a whopping 24 gigabytes of memory! I read in some other threads about the importance of the writer_batch_size parameter and how it controls the size of the Arrow record batches that get written and memory-mapped. Since my dataset consists of very long documents, the default value of 1000 might be too large for my use case. I’ve tried WRITER_BATCH_SIZE = 10, WRITER_BATCH_SIZE = 100, and of course WRITER_BATCH_SIZE = 1000. All yield the same result: 24 gigabytes of memory used by the dataset. I am confused about why this is happening and would appreciate any advice about why the dataset takes so much memory even though it is supposed to be stored on disk.
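
In case it is relevant, here is a small diagnostic I can run after the map call (a sketch added for this post, not part of the script above): it lists the Arrow cache files backing the dataset and their on-disk sizes, which I can then compare with the RSS numbers from print_mem().

# List the Arrow cache files that back the (supposedly on-disk) dataset
# and report their sizes, to compare with the RSS printed above.
for cache_file in dataset.cache_files:
    path = cache_file["filename"]
    size_gb = os.path.getsize(path) / 1024 ** 3
    print(f"{path}: {size_gb:.2f} GB on disk")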

One factor I’m worried could be causing an issue is that I am working on a cluster managed by SLURM. I know how to allocate enough memory to a job with SLURM’s sbatch, but I’m wondering whether the fact that the job executes on a remote compute node could be part of the problem.

UPDATE

I also tried this with WRITER_BATCH_SIZE = 1 and WRITER_BATCH_SIZE = 200. None of the sizes has much influence on the resulting memory usage (24G–30G). Here is a link to one of my documents. As you can see, it is a very long document. The function passed to map chops each of these documents into pieces of length 4096.
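
For reference, the map function has roughly this shape (a simplified sketch, not my exact code; the column names "text" and "chunk" are placeholders):

def fn(batch):
    # Split every document in the batch into fixed-length pieces of 4096
    # characters; the output batch has more rows than the input batch.
    chunks = []
    for doc in batch["text"]:
        chunks.extend(doc[i : i + 4096] for i in range(0, len(doc), 4096))
    return {"chunk": chunks}

Since the number of rows changes, the original columns also have to be dropped (e.g. with remove_columns) so their lengths don’t clash with the new rows; I’ve left that detail out of the snippet.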

The issue was that I iterated through the dataset and computed some statistics about its elements using a simple Python for-loop. This apparently caused the entire dataset to be pulled into memory and stay there, which is not something I anticipated. Simply deleting that loop resolved the issue entirely.
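
For anyone who runs into the same thing, the offending code was essentially of this shape (simplified; the real statistic was more involved, and "chunk" is a placeholder column name):

# Simplified version of the statistics loop that made memory usage balloon:
# iterating example by example touches every row of the dataset.
total_chars = 0
for example in dataset:
    total_chars += len(example["chunk"])
print(total_chars)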