I’m using PyTorch Lightning with a Hugging Face IterableDataset and want to shard the dataset across processes using split_dataset_by_node.
Unfortunately this isn’t possible out of the box using e.g.

    from datasets.distributed import split_dataset_by_node
    from torch.distributed import get_rank, get_world_size
    ...
    dataset = split_dataset_by_node(dataset, rank=get_rank(), world_size=get_world_size())

because when using Lightning, the distributed environment is only set up internally by the Trainer.
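To make that concrete, here’s a minimal sketch of what happens if those calls run eagerly, i.e. in the main process before trainer.fit() (the exception type varies across PyTorch versions, hence catching both):

    import torch.distributed as dist

    # Before trainer.fit() initialises the distributed backend there is no
    # default process group, so rank/world-size queries cannot succeed.
    print(dist.is_initialized())  # False at dataset-construction time

    try:
        dist.get_rank()
    except (RuntimeError, ValueError) as e:
        # e.g. "Default process group has not been initialized, ..."
        print(e)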
My code is roughly:
    from datasets import IterableDataset
    from pytorch_lightning import Trainer
    from torch.utils.data import DataLoader

    dataset = IterableDataset.from_generator(some_generator)
    dataloader = DataLoader(dataset, batch_size=512, num_workers=5)
    model = MyModel()
    trainer = Trainer()
    trainer.fit(model, dataloader)
So I need to somehow defer the calls to get_rank and get_world_size until the DataLoader is iterated over by the Trainer.
In the case of a vanilla PyTorch IterableDataset you can do this by defining your own __iter__ method, e.g.
    from torch.distributed import get_rank, get_world_size
    from torch.utils.data import IterableDataset, get_worker_info

    class MyDataset(IterableDataset):
        ...
        def __iter__(self):
            worker_info = get_worker_info()
            world_size = get_world_size()
            process_rank = get_rank()
            # ... do stuff e.g. split data manually by rank ...
which works fine.
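Spelled out, the manual splitting might look like this — a minimal sketch assuming the data is an in-memory sequence, combining the distributed rank with the DataLoader worker id so that every shard is disjoint:

    from itertools import islice

    from torch.distributed import get_rank, get_world_size, is_initialized
    from torch.utils.data import IterableDataset, get_worker_info

    class MyDataset(IterableDataset):
        def __init__(self, data):
            self.data = data  # assumption: an in-memory sequence of rows

        def __iter__(self):
            # Resolved lazily: by the time the Trainer iterates the loader,
            # the process group exists and these calls return real values.
            world_size = get_world_size() if is_initialized() else 1
            process_rank = get_rank() if is_initialized() else 0
            worker_info = get_worker_info()
            num_workers = worker_info.num_workers if worker_info else 1
            worker_id = worker_info.id if worker_info else 0
            # Give every (rank, dataloader-worker) pair a disjoint strided shard.
            num_shards = world_size * num_workers
            shard_id = process_rank * num_workers + worker_id
            return islice(iter(self.data), shard_id, None, num_shards)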
I tried to do the equivalent by subclassing the HF IterableDataset, e.g.
    class IterableDatasetLightning(IterableDataset):
        def __iter__(self):
            self = split_dataset_by_node(self, rank=get_rank(), world_size=get_world_size())
            return iter(self)
But this __iter__ method is never called. Even if I put an assert False in there, it is simply never hit.
Reading through the source code, I can’t figure out what I should be overriding.
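For what it’s worth, a quick diagnostic (my own guess from skimming the source, so treat it as an assumption) suggests from_generator returns a plain datasets.IterableDataset rather than an instance of my subclass, which would explain why the override is never hit:

    from datasets import IterableDataset

    class IterableDatasetLightning(IterableDataset):
        def __iter__(self):
            assert False  # never reached

    ds = IterableDatasetLightning.from_generator(some_generator)
    # Assumption from reading the source: this prints the base class,
    # datasets.IterableDataset, not IterableDatasetLightning, so the
    # subclass __iter__ above is never in play.
    print(type(ds))
    next(iter(ds))  # iterates fine; the assert never fires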
I’ve found a hack of a workaround:
    from datasets import IterableDataset
    from datasets.distributed import split_dataset_by_node
    from torch.distributed import get_rank, get_world_size

    class IterableDatasetLightning(IterableDataset):
        def __init__(self, data_files):
            # HACK: creating an instance of the class within the class itself
            self.dataset = self.from_generator(
                generator=self.dataset_generator,
                gen_kwargs={"data_files": data_files},  # keys must match the generator's parameter names
            )
            self.multi_gpu = ...  # set elsewhere in my real code

        def dataset_generator(self, data_files):
            ...  # some generator that yields rows of data

        def __iter__(self):
            world_size = get_world_size() if self.multi_gpu else 1
            process_rank = get_rank() if self.multi_gpu else 0
            if world_size > 1:
                self.dataset = split_dataset_by_node(self.dataset, rank=process_rank, world_size=world_size)
            return iter(self.dataset)
This works as I want it to, but creating an instance of the class inside the class itself feels like quite a hack.
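For completeness, this is roughly how I wire it up (MyModel and my_data_files are placeholders for my setup above):

    from pytorch_lightning import Trainer
    from torch.utils.data import DataLoader

    dataset = IterableDatasetLightning(data_files=my_data_files)
    dataloader = DataLoader(dataset, batch_size=512, num_workers=5)
    model = MyModel()
    # get_rank()/get_world_size() only run once the Trainer starts iterating
    # the dataloader, i.e. after it has set up the distributed environment.
    trainer = Trainer()
    trainer.fit(model, dataloader)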
Is there a saner solution?