Multi-GPU Training using split_dataset_by_node with PyTorch Lightning

So I’m not a pro here, but this might be relevant:

  • webdataset solves the issue of distributing shards across nodes with a dataset(…, nodesplitter=nodesplitter_func) approach. This nodesplitter_func can be anything, but typically calls torch.distributed.get_rank() etc., like you tried (see the sketch after this list)
  • You can/should place the dataset() call inside a datamodule
  • The distributed environment exists by the time datamodule.setup() is called, so functions called from the datamodule’s *_dataloader() methods can read the rank and world size
  • nodesplitter_func is called when each worker is set up, but that’s okay: it is deterministic, so it returns the same shard list for every worker on a given node

I think the same approach might work for split_dataset_by_node here. You would need to split the dataset inside webdatamodule, in or after webdatamodule.setup(), by which point the distributed environment should already be set up for you. Roughly like the sketch below.
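Something like this, assuming a streaming Hugging Face dataset (the dataset name is a placeholder, and I haven’t tested it):

```python
import pytorch_lightning as pl
import torch.distributed as dist
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from torch.utils.data import DataLoader


class WebDataModule(pl.LightningDataModule):
    def setup(self, stage=None):
        # Under DDP, torch.distributed is initialized before setup()
        # runs, so rank and world size are available here.
        dataset = load_dataset("user/dataset", split="train", streaming=True)
        if dist.is_available() and dist.is_initialized():
            dataset = split_dataset_by_node(
                dataset,
                rank=dist.get_rank(),
                world_size=dist.get_world_size(),
            )
        self.train_dataset = dataset

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=32)
```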