So I’m not an expert here, but this might be relevant:
- webdataset handles distributing data across nodes by passing a node-splitting function to the dataset: dataset(…, nodesplitter_func=nodesplitter_func). This nodesplitter_func can be anything, but it typically calls torch.distributed.get_rank() and get_world_size() to pick this node's shards, like you tried
- You can (and should) call dataset() inside a datamodule
- The distributed environment is already initialized by the time datamodule.setup() is called, so anything called from the datamodule's get_dataloader functions can read the rank and world size
- nodesplitter_func is called again as each dataloader worker is set up, but that’s okay: it's deterministic, so it returns the same shard list for every worker on a given node
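Here's a rough sketch of what such a node splitter could look like. It's my own illustration, not webdataset's code: the names `split_by_rank` and `current_rank_and_world_size` are hypothetical, and the RANK/WORLD_SIZE env-var fallback is an assumption (dataloader workers started with spawn may not see an initialized process group). It takes the shard list as an iterable and keeps every world_size-th shard starting at this rank, which is deterministic, so re-running it in each worker is harmless.

```python
import os
from itertools import islice
from typing import Iterable, Iterator, Optional, Tuple, TypeVar

T = TypeVar("T")


def current_rank_and_world_size() -> Tuple[int, int]:
    """Best-effort (rank, world_size): torch.distributed first, then the
    RANK/WORLD_SIZE env vars (assumed set by the launcher), else (0, 1)."""
    try:
        import torch.distributed as dist

        if dist.is_available() and dist.is_initialized():
            return dist.get_rank(), dist.get_world_size()
    except ImportError:
        pass
    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
        return int(os.environ["RANK"]), int(os.environ["WORLD_SIZE"])
    return 0, 1


def split_by_rank(
    src: Iterable[T], rank: Optional[int] = None, world_size: Optional[int] = None
) -> Iterator[T]:
    """Yield every world_size-th item of src, starting at this process's rank.

    Deterministic: the same shard list, rank and world_size always give the
    same subset, so it's fine that every dataloader worker calls it again.
    """
    default_rank, default_world_size = current_rank_and_world_size()
    rank = default_rank if rank is None else rank
    world_size = default_world_size if world_size is None else world_size
    if world_size > 1:
        yield from islice(src, rank, None, world_size)
    else:
        yield from src


if __name__ == "__main__":
    shards = [f"shard-{i:04d}.tar" for i in range(8)]
    print(list(split_by_rank(shards, rank=1, world_size=4)))  # ['shard-0001.tar', 'shard-0005.tar']
```

With webdataset you'd pass it to the dataset constructor, e.g. wds.WebDataset(urls, nodesplitter=split_by_rank) (in recent webdataset releases the keyword is nodesplitter, worth checking against your version); webdataset's built-in wds.split_by_node works along the same lines.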
I think the same approach might work for split_dataset_by_node here. You would need to split the dataset inside webdatamodule, no earlier than webdatamodule.setup(); by then the distributed environment should already be set up for you, so you can pass it the rank and world size.