datasets.Dataset.map() idle processes when multiprocessing

I’m running datasets.Dataset.map() with num_proc=64, and after a while the CPU utilization falls far below 100% (3.16%).
This suggests workers are assigned a fixed list of jobs at the start and go idle once they finish that list, instead of pulling one of the remaining jobs on demand.

Is my intuition correct? And if so, what’s the thinking behind this design choice? Wouldn’t it be more efficient to keep a shared pool of jobs available to all workers, so they all stay busy until no jobs are left?

Thank you

Hi! When num_proc > 1, map splits the dataset into num_proc shards, each of which is mapped to one of the num_proc workers. So in your case, this means that some workers finished processing their shards earlier than others.
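To see why this static split can leave workers idle, here’s a toy cost model (illustrative only, not datasets’ code): with one fixed contiguous shard per worker, the wall time is set by the most expensive shard, not by total cost divided by num_proc.

```python
def shard_costs(row_costs, num_proc):
    """Total per-row cost of each of num_proc contiguous shards."""
    base, extra = divmod(len(row_costs), num_proc)
    totals, start = [], 0
    for i in range(num_proc):
        end = start + base + (1 if i < extra else 0)
        totals.append(sum(row_costs[start:end]))
        start = end
    return totals

# 8 cheap rows followed by 8 expensive rows, split across 4 workers:
costs = [1] * 8 + [10] * 8
per_worker = shard_costs(costs, 4)   # [4, 4, 40, 40]
# ideal wall time ~ sum(costs) / 4 = 22; static-shard wall time ~ max = 40,
# during which workers 0 and 1 sit idle after finishing their cheap shards
```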

Thanks, just like I thought.

So my question then is: what’s the thought behind this design choice? Surely implementing a multiprocessing queue of rows/batches to take on would prevent processes from going idle due to the workload not being evenly distributed across shards?

Also, is there any guarantee on the way the dataset is currently sharded? Is the dataset simply split into consecutive contiguous shards of size ~len(data)/num_proc, where worker 0 gets rows 0 to len(data)/num_proc, worker 1 gets rows len(data)/num_proc to 2*len(data)/num_proc, and so on?

EDIT: looks like it’s contiguous, from source:
self.shard(num_shards=num_proc, index=rank, contiguous=True, keep_in_memory=keep_in_memory)
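For reference, the index ranges produced by contiguous sharding can be sketched like this (a reconstruction based on the documented behavior of Dataset.shard with contiguous=True, where the first len(data) % num_shards shards get one extra row; not copied from the library’s source):

```python
def contiguous_shard_ranges(n_rows, num_shards):
    """Return the [start, end) row range each shard/worker receives."""
    div, mod = divmod(n_rows, num_shards)
    ranges = []
    for index in range(num_shards):
        start = div * index + min(index, mod)
        end = start + div + (1 if index < mod else 0)
        ranges.append((start, end))
    return ranges

# e.g. 10 rows over 3 workers -> [(0, 4), (4, 7), (7, 10)]
```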

You can shuffle the dataset before map to distribute the workload evenly.
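In the toy cost model above, shuffling before sharding spreads the expensive rows across shards (a stdlib sketch, not datasets’ implementation; with the real library this would be ds.shuffle(seed=...) followed by map with num_proc):

```python
import random

def shard_costs(row_costs, num_proc):
    """Total per-row cost of each of num_proc contiguous shards."""
    base, extra = divmod(len(row_costs), num_proc)
    totals, start = [], 0
    for i in range(num_proc):
        end = start + base + (1 if i < extra else 0)
        totals.append(sum(row_costs[start:end]))
        start = end
    return totals

costs = [1] * 8 + [10] * 8
random.Random(42).shuffle(costs)     # shuffle rows before sharding
per_worker = shard_costs(costs, 4)
# the max shard cost now tends toward sum(costs) / 4 instead of the
# worst case of 40, so workers finish at roughly the same time
```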

So my question then is: what’s the thought behind this design choice? Surely implementing a multiprocessing queue of rows/batches to take on would prevent processes from going idle due to the workload not being evenly distributed across shards?

We can consider this, but this would definitely make our code more complex and harder to maintain.

I see, thanks! Shuffling would probably help, but there would still be idle workers toward the end if the variance of per-row workload is high. I’d be happy to contribute, since I’m already working on a custom multiprocessing solution as a workaround for this behavior.
I’d appreciate it if you could point me at what needs attention, or explain your concerns around complexity/maintainability.
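For what it’s worth, here’s a minimal sketch of the kind of dynamic scheduling I have in mind (an illustration, not datasets’ API): cut the data into many small batches and let Pool.imap_unordered hand the next batch to whichever worker is free, so nobody idles until the queue is empty. process_batch is a hypothetical stand-in for the user’s map function.

```python
from multiprocessing import Pool

def process_batch(args):
    """Stand-in for the user's map function applied to one batch of rows."""
    batch_id, rows = args
    return batch_id, [x * 2 for x in rows]

def dynamic_map(data, num_proc, batch_size):
    # Many small batches instead of num_proc big shards: free workers
    # pull the next batch as soon as they finish one.
    batches = [
        (i, data[start:start + batch_size])
        for i, start in enumerate(range(0, len(data), batch_size))
    ]
    with Pool(processes=num_proc) as pool:
        # chunksize=1 keeps the scheduling fully dynamic
        results = list(pool.imap_unordered(process_batch, batches, chunksize=1))
    results.sort()  # restore original batch order
    return [x for _, rows in results for x in rows]

if __name__ == "__main__":
    print(dynamic_map(list(range(10)), num_proc=2, batch_size=3))
```

The trade-off the maintainers mention is real: results arrive out of order and must be reassembled, and batch boundaries no longer line up with shard/cache boundaries.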

You are more than welcome to contribute. Feel free to open an issue/PR if you need some help/pointers.

Unless we drop Pool in the multiprocessing map, there will always be some idle workers by the end, since Pool has to maintain the specified number of workers (and there is no public API to change this). But replacing Pool with some other logic would most likely make our code more complex, which is probably not worth it unless it leads to significant performance savings.