Multi-GPU Training using split_dataset_by_node with PyTorch Lightning

I’m using PyTorch Lightning with a Hugging Face IterableDataset and want to shard the dataset across GPUs using split_dataset_by_node.

Unfortunately this isn’t possible out of the box with e.g.

from torch.distributed import get_rank, get_world_size
from datasets.distributed import split_dataset_by_node

dataset = split_dataset_by_node(dataset, rank=get_rank(), world_size=get_world_size())

because when using Lightning, the distributed environment is only set up internally by the Trainer; at the point where the dataset is built, the process group doesn’t exist yet, so get_rank() and get_world_size() fail.
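To make the failure mode concrete: before any init_process_group call (which Lightning performs inside the Trainer), the default process group doesn’t exist, and rank/world-size queries raise. A minimal sketch of the guard pattern, using only stock torch.distributed:

```python
import torch.distributed as dist

# Outside Trainer.fit(), no process group has been initialized,
# so get_rank()/get_world_size() would raise a RuntimeError here.
assert not dist.is_initialized()

# Safe pattern: fall back to single-process defaults when the
# distributed environment is not (yet) set up.
rank = dist.get_rank() if dist.is_initialized() else 0
world_size = dist.get_world_size() if dist.is_initialized() else 1
print(rank, world_size)  # 0 1 in a non-distributed run
```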

My code is roughly

dataset = IterableDataset.from_generator(some_generator)

dataloader = DataLoader(dataset, batch_size=512, num_workers=5)
model = MyModel()
trainer = Trainer()
trainer.fit(model, dataloader)

So I need to somehow defer the calls to get_rank and get_world_size until the DataLoader is iterated over by the Trainer.

With a vanilla PyTorch IterableDataset you can do this by defining your own __iter__ method


from torch.distributed import get_rank, get_world_size
from torch.utils.data import IterableDataset, get_worker_info

class MyDataset(IterableDataset):
    def __iter__(self):
        worker_info = get_worker_info()
        world_size = get_world_size()
        process_rank = get_rank()

        ... do stuff e.g. split data manually by rank ...

which works fine.
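The “split data manually by rank” part usually comes down to striding over the source with a combined process-rank/worker-id offset. The arithmetic itself is plain Python; a sketch with itertools.islice (the shard helper and the 2×2 example numbers are mine, for illustration):

```python
from itertools import islice

def shard(iterable, global_rank, total_shards):
    """Round-robin shard: element i goes to shard i % total_shards."""
    return islice(iterable, global_rank, None, total_shards)

# Combine process rank and DataLoader worker id into one shard index,
# e.g. 2 processes x 2 workers -> 4 shards in total.
world_size, num_workers = 2, 2
process_rank, worker_id = 1, 0
total_shards = world_size * num_workers
global_rank = process_rank * num_workers + worker_id

data = range(8)
print(list(shard(data, global_rank, total_shards)))  # [2, 6]
```

Every element lands in exactly one shard, so the shards together cover the dataset without duplication.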

I tried to do the equivalent by subclassing HF IterableDataset e.g.

class IterableDatasetLightning(IterableDataset):
    def __iter__(self):
        self = split_dataset_by_node(self, rank=get_rank(), world_size=get_world_size())
        return iter(self)

But this __iter__ method is never called. Even if I put an assert False in there, it is never hit.

Reading through the source code, I can’t figure out what I should be overriding.

I’ve found a hacky workaround:

class IterableDatasetLightning(IterableDataset):
    def __init__(self, data_files, multi_gpu=True):
        self.multi_gpu = multi_gpu
        # HACK: creating an instance of the class within the class itself
        self.dataset = self.from_generator(generator=self.dataset_generator, gen_kwargs={"files": data_files})

    def dataset_generator(self, files):
        ... some generator that yields rows of data ...

    def __iter__(self):
        world_size = get_world_size() if self.multi_gpu else 1
        process_rank = get_rank() if self.multi_gpu else 0
        if world_size > 1:
            self.dataset = split_dataset_by_node(self.dataset, rank=process_rank, world_size=world_size)
        return iter(self.dataset)

This works as I want it to, but creating an instance of the class inside the class itself feels like quite a hack.
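For comparison, one variant I’ve considered is composition instead of subclassing: wrap the HF dataset in a plain torch IterableDataset whose __iter__ the DataLoader is guaranteed to call, and defer the split to there. A sketch (the SplitOnIter name and structure are mine, not an established API):

```python
import torch.distributed as dist
from torch.utils.data import IterableDataset


class SplitOnIter(IterableDataset):
    """Wraps an iterable dataset; splits it per node lazily, at iteration time."""

    def __init__(self, hf_dataset):
        self.hf_dataset = hf_dataset

    def __iter__(self):
        dataset = self.hf_dataset
        # Only split once the Trainer has set up the process group.
        if dist.is_initialized() and dist.get_world_size() > 1:
            from datasets.distributed import split_dataset_by_node
            dataset = split_dataset_by_node(
                dataset,
                rank=dist.get_rank(),
                world_size=dist.get_world_size(),
            )
        return iter(dataset)
```

In a single-process run (no process group) the wrapper just passes iteration through to the underlying dataset.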

Is there a saner solution?