Multi-GPU Training using split_dataset_by_node with PyTorch Lightning

I’m using PyTorch Lightning with a Hugging Face IterableDataset and want to divide the dataset across processes using split_dataset_by_node.

Unfortunately this isn’t possible out of the box using, for example,

from torch.distributed import get_rank, get_world_size
from datasets.distributed import split_dataset_by_node
...
dataset = split_dataset_by_node(dataset, rank=get_rank(), world_size=get_world_size())

because when using Lightning, the distributed environment is only set up internally by the Trainer.
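For illustration, this is roughly the failure mode if you call these eagerly (a minimal sketch; the exact error message may differ across PyTorch versions):

import torch.distributed as dist

# Before trainer.fit() has initialized the process group,
# rank queries fail:
print(dist.is_initialized())  # False
dist.get_rank()  # raises: default process group has not been initialized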

My code is roughly:

from datasets import IterableDataset
from torch.utils.data import DataLoader
from pytorch_lightning import Trainer

dataset = IterableDataset.from_generator(some_generator)

dataloader = DataLoader(dataset, batch_size=512, num_workers=5)
model = MyModel()
trainer = Trainer()
trainer.fit(model, dataloader)

So I need to somehow defer the calls to get_rank and get_world_size until the DataLoader is actually iterated over by the Trainer.

In the case of a vanilla PyTorch IterableDataset you can do this by defining your own __iter__ method, e.g.

from torch.distributed import get_rank, get_world_size
from torch.utils.data import IterableDataset, get_worker_info

class MyDataset(IterableDataset):
    ...

    def __iter__(self):
        worker_info = get_worker_info()
        world_size = get_world_size()
        process_rank = get_rank()

        # ... do stuff, e.g. split the data manually by rank ...

which works fine.
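For reference, this is roughly what I mean by splitting manually; a minimal sketch, where the self.data attribute and the modulo sharding scheme are just illustrative:

import torch.distributed as dist
from torch.utils.data import IterableDataset, get_worker_info

class MyDataset(IterableDataset):
    def __init__(self, data):
        self.data = data

    def __iter__(self):
        # Combine the DataLoader worker id and the distributed rank
        # into a single global shard index, then stride over the data.
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info else 1
        worker_id = worker_info.id if worker_info else 0
        world_size = dist.get_world_size() if dist.is_initialized() else 1
        rank = dist.get_rank() if dist.is_initialized() else 0

        num_shards = world_size * num_workers
        shard_id = rank * num_workers + worker_id
        for i, row in enumerate(self.data):
            if i % num_shards == shard_id:
                yield row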

I tried to do the equivalent by subclassing the Hugging Face IterableDataset, e.g.

from datasets import IterableDataset
from datasets.distributed import split_dataset_by_node
from torch.distributed import get_rank, get_world_size

class IterableDatasetLightning(IterableDataset):
    def __iter__(self):
        self = split_dataset_by_node(self, rank=get_rank(), world_size=get_world_size())
        return iter(self)

But this __iter__ method is never called. Even if I put an assert False in there, it just never runs.

Reading through the source code, I can’t figure out what I should be overriding.

I’ve found a hacky workaround:

from datasets import IterableDataset
from datasets.distributed import split_dataset_by_node
from torch.distributed import get_rank, get_world_size

class IterableDatasetLightning(IterableDataset):
    def __init__(self, data_files, multi_gpu=True):
        self.multi_gpu = multi_gpu
        # HACK: creating an instance of the class within the class itself
        self.dataset = self.from_generator(generator=self.dataset_generator, gen_kwargs={"data_files": data_files})

    def dataset_generator(self, data_files):
        ...  # some generator that yields rows of data

    def __iter__(self):
        world_size = get_world_size() if self.multi_gpu else 1
        process_rank = get_rank() if self.multi_gpu else 0
        if world_size > 1:
            self.dataset = split_dataset_by_node(self.dataset, rank=process_rank, world_size=world_size)
        return iter(self.dataset)

This works as I want it to, but creating an instance of the class within the class itself feels like quite a hack.
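For context, I use it roughly like this (my_data_files is a placeholder for my actual file list):

dataset = IterableDatasetLightning(data_files=my_data_files)

dataloader = DataLoader(dataset, batch_size=512, num_workers=5)
model = MyModel()
trainer = Trainer()
trainer.fit(model, dataloader)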

Is there a saner solution?