Multi-GPU training using split_dataset_by_node with PyTorch Lightning

I’m using PyTorch Lightning with an IterableDataset and want to shard the dataset across GPUs using split_dataset_by_node.

Unfortunately this isn’t possible out of the box using, e.g.,

from datasets.distributed import split_dataset_by_node
from torch.distributed import get_rank, get_world_size
...
dataset = split_dataset_by_node(dataset, rank=get_rank(), world_size=get_world_size())

because when using Lightning, the distributed environment is only set up internally by the Trainer.

My code is roughly:

from datasets import IterableDataset
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer

dataset = IterableDataset.from_generator(some_generator)

dataloader = DataLoader(dataset, batch_size=512, num_workers=5)
model = MyModel()
trainer = Trainer()
trainer.fit(model, dataloader)

So I need to somehow defer the calls to get_rank and get_world_size until the DataLoader is iterated over by the Trainer.

In the case of a vanilla PyTorch IterableDataset you can do this by defining your own __iter__ method,

e.g.

from torch.distributed import get_rank, get_world_size
from torch.utils.data import IterableDataset, get_worker_info

class MyDataset(IterableDataset):
    ...

    def __iter__(self):
        worker_info = get_worker_info()
        world_size = get_world_size()
        process_rank = get_rank()

        ... do stuff e.g. split data manually by rank ...

which works fine.
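
For completeness, a minimal sketch of what the manual split could look like (the shard arithmetic below is illustrative rather than from the original post, and assumes the dataset wraps an indexable list of samples and that torch.distributed is already initialized):

from torch.distributed import get_rank, get_world_size
from torch.utils.data import IterableDataset, get_worker_info

class MyDataset(IterableDataset):
    def __init__(self, items):
        self.items = items  # any indexable collection of samples

    def __iter__(self):
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info else 1
        worker_id = worker_info.id if worker_info else 0
        world_size = get_world_size()
        process_rank = get_rank()

        # each (rank, worker) pair reads a disjoint, strided slice of the data
        shard_id = process_rank * num_workers + worker_id
        num_shards = world_size * num_workers
        for idx in range(shard_id, len(self.items), num_shards):
            yield self.items[idx]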

I tried to do the equivalent by subclassing the HF IterableDataset, e.g.

from datasets import IterableDataset
from datasets.distributed import split_dataset_by_node
from torch.distributed import get_rank, get_world_size

class IterableDatasetLightning(IterableDataset):
    def __iter__(self):
        dataset = split_dataset_by_node(self, rank=get_rank(), world_size=get_world_size())
        return iter(dataset)

But this __iter__ method is never called. Even if I put an assert False in there, it just never gets hit.

Reading through the source code, I can’t figure out what I should be overriding.

I’ve found a hacky workaround:

from datasets import IterableDataset
from datasets.distributed import split_dataset_by_node
from torch.distributed import get_rank, get_world_size

class IterableDatasetLightning(IterableDataset):
    def __init__(self, data_files, multi_gpu=True):
        self.multi_gpu = multi_gpu
        # HACK: creating an instance of the class within the class itself
        self.dataset = self.from_generator(generator=self.dataset_generator, gen_kwargs={"files": data_files})

    def dataset_generator(self, files):
        ... some generator that yields rows of data ...

    def __iter__(self):
        world_size = get_world_size() if self.multi_gpu else 1
        process_rank = get_rank() if self.multi_gpu else 0
        if world_size > 1:
            self.dataset = split_dataset_by_node(self.dataset, rank=process_rank, world_size=world_size)
        return iter(self.dataset)

This works as I want it to, but creating an instance of the class within the class itself feels like quite a hack.
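
For reference, this is roughly how it gets wired up (my_data_files and the multi_gpu flag are placeholders for however you configure your run):

from torch.utils.data import DataLoader
from pytorch_lightning import Trainer

dataset = IterableDatasetLightning(data_files=my_data_files, multi_gpu=True)
dataloader = DataLoader(dataset, batch_size=512, num_workers=5)

trainer = Trainer(accelerator="gpu", devices=2, strategy="ddp")
trainer.fit(MyModel(), dataloader)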

Is there a saner solution?

So I’m not a pro here, but maybe relevant:

  • webdataset solves the issue of distributing across nodes with a dataset(…, nodesplitter_func=nodesplitter_func) approach. This nodesplitter_func can be anything, but it typically calls torch.distributed.get_rank etc., like you tried (see the sketch after this list).
  • You can/should place the dataset() call inside a datamodule.
  • The distributed environment exists by the time datamodule.setup() is called, and functions called inside the datamodule’s dataloader methods can read the rank/world size.
  • nodesplitter_func is called when each worker is set up, but that’s okay because it returns the same list for all workers on a given node (being deterministic).
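
A minimal sketch of what such a nodesplitter function could look like (illustrative only; the guard around torch.distributed and the strided split are my assumptions, not from the post above):

import torch.distributed as dist

def nodesplitter_func(urls):
    # Give each node/rank a disjoint, strided subset of the shard URLs.
    # Deterministic, so every worker on a given node computes the same list.
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
        return list(urls)[rank::world_size]
    return list(urls)

# then passed to the dataset construction, e.g. dataset(..., nodesplitter_func=nodesplitter_func)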

I think the same thing might work for split_dataset_by_node here. You would need to split the dataset inside your datamodule, after datamodule.setup() has been called, at which point the distributed environment should already be set up for you (rough sketch below).
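
A rough sketch of what that could look like with a LightningDataModule (untested; it assumes the Trainer has already initialized torch.distributed before setup() runs, and it reuses the some_generator / MyModel names from above):

import torch.distributed as dist
from datasets import IterableDataset
from datasets.distributed import split_dataset_by_node
from torch.utils.data import DataLoader
from pytorch_lightning import LightningDataModule, Trainer

class MyDataModule(LightningDataModule):
    def setup(self, stage=None):
        # called on every process after the Trainer has set up the distributed environment
        self.dataset = IterableDataset.from_generator(some_generator)
        if dist.is_available() and dist.is_initialized():
            self.dataset = split_dataset_by_node(
                self.dataset, rank=dist.get_rank(), world_size=dist.get_world_size()
            )

    def train_dataloader(self):
        return DataLoader(self.dataset, batch_size=512, num_workers=5)

trainer = Trainer()
trainer.fit(MyModel(), datamodule=MyDataModule())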