I’m using datasets.IterableDataset (more specifically IterableDataset.from_generator). I have been using it with PyTorch DDP by streaming all the data across multiple GPU nodes from a source, since caching the whole dataset locally requires too much disk space.
However, if I could shard the dataset per DDP node, then each node’s data could fit on disk. So I am wondering: is it possible to use IterableDataset.from_generator so that:
Each DDP node gets assigned a fixed subset of the shards (e.g., based on a seed) for the whole run.
Each DDP node shuffles its assigned shards each epoch.
You can use shuffle and set_epoch to shuffle the shards and samples in between epochs (explained here in the docs) and split_dataset_by_node to split the dataset across nodes.
For this to work efficiently, the dataset must consist of many shards (n_shards returns the number of shards; ideally dataset.n_shards % world_size == 0). An example of creating a sharded dataset is available here (shards are formed by sharding gen_kwargs’s list values).
But could you give an example of how to keep the dataset split fixed for each node during the whole DDP run?
The documentation had the following example for split_dataset_by_node:
from datasets.distributed import split_dataset_by_node
ids = ds.to_iterable_dataset(num_shards=512)
ids = ids.shuffle(buffer_size=10_000) # will shuffle the shards order and use a shuffle buffer when you start iterating
ids = split_dataset_by_node(ids, world_size=8, rank=0) # will keep only 512 / 8 = 64 shards from the shuffled lists of shards when you start iterating
dataloader = torch.utils.data.DataLoader(ids, num_workers=4) # will assign 64 / 4 = 16 shards from this node's list of shards to each worker when you start iterating
for example in ids:
    pass
However, if I understand correctly, in this example each node would get a new random subset of shards every epoch, and then my nodes cannot reuse their previous local cache.
from datasets.distributed import split_dataset_by_node
ids = ds.to_iterable_dataset(num_shards=512)
ids = ids.shuffle(buffer_size=10_000) # will shuffle the shards order and use a shuffle buffer when you start iterating
ids = split_dataset_by_node(ids, world_size=8, rank=0) # will keep only 512 / 8 = 64 shards from the shuffled lists of shards when you start iterating
dataloader = torch.utils.data.DataLoader(ids, num_workers=4) # will assign 64 / 4 = 16 shards from this node's list of shards to each worker when you start iterating
for epoch in range(num_epochs):
    ids.set_epoch(epoch) # operates on the dataset split
    for example in dataloader:
        pass
I think this is for consistency with PyTorch’s DistributedSampler.set_epoch (@lhoestq should know more as the author of this feature).
If the dataset in question is a Hub dataset without a loading script, one option is to manually shuffle and split the files/shards to ensure they stay fixed in each epoch.
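As a hedged sketch of that manual approach (file names, seed, and rank/world_size values are placeholders; in a real run rank and world size would come from torch.distributed): shuffle the shard list once with a fixed seed shared by all nodes, slice it by rank, and pass only that fixed subset to IterableDataset.from_generator. Shuffling with set_epoch then only reorders within the node's fixed subset, so the local cache stays valid across epochs.

```python
import random

def node_shards(all_paths, rank, world_size, seed=42):
    """Deterministic, disjoint shard assignment for one DDP rank."""
    paths = sorted(all_paths)      # identical starting order on every node
    rng = random.Random(seed)      # identical seed -> identical global shuffle
    rng.shuffle(paths)
    return paths[rank::world_size] # fixed, non-overlapping slice per rank

# Hypothetical shard files for a 512-shard dataset on 8 nodes.
paths = [f"data/shard-{i:05d}.jsonl" for i in range(512)]
subset = node_shards(paths, rank=0, world_size=8)
print(len(subset))  # 64 -- the same 64 files every run, so the cache is reusable
```

Calling node_shards again with the same arguments always returns the same files, and the slices for different ranks are disjoint and together cover the whole dataset.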