Keeping IterableDataset node-wise split fixed during DDP

Thanks, yes, the dataset has many shards.

But could you give an example of how to keep the dataset split fixed for each node during the whole DDP run?

The documentation had the following example for split_dataset_by_node:

import torch  # needed for the DataLoader below
from datasets.distributed import split_dataset_by_node

ids = ds.to_iterable_dataset(num_shards=512)  # ds is a regular (map-style) Dataset
ids = ids.shuffle(buffer_size=10_000)  # will shuffle the shards order and use a shuffle buffer when you start iterating
ids = split_dataset_by_node(ids, world_size=8, rank=0)  # will keep only 512 / 8 = 64 shards from the shuffled list of shards when you start iterating
dataloader = torch.utils.data.DataLoader(ids, num_workers=4)  # will assign 64 / 4 = 16 shards from this node's list of shards to each worker when you start iterating
for example in dataloader:
    pass

However, if I understand correctly, in this example each node would get a new random subset of shards every epoch, so my nodes cannot re-use their previous local cache.
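For reference, here is what I had in mind (my own sketch, not from the docs): split by node first and only shuffle afterwards, then use set_epoch to reshuffle each epoch. Does this keep the set of shards assigned to each node fixed across epochs, so the local cache can be reused?

import torch
from datasets.distributed import split_dataset_by_node

ids = ds.to_iterable_dataset(num_shards=512)
ids = split_dataset_by_node(ids, world_size=8, rank=0)  # split before shuffling, hoping this node keeps the same 64 shards every epoch
ids = ids.shuffle(buffer_size=10_000, seed=42)  # shuffle only within this node's shards + shuffle buffer
dataloader = torch.utils.data.DataLoader(ids, num_workers=4)

for epoch in range(3):  # 3 epochs just as an example
    ids.set_epoch(epoch)  # reshuffle each epoch -- does this change only the order, not the node's shard assignment?
    for example in dataloader:
        pass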

Thanks again!