How to handle streaming datasets with DDP?

Let’s say I have a dataset of 5 samples with values [1, 2, 3, 4, 5], 2 GPUs (for DDP), and a batch size of 2. The dataset is an IterableDataset because I am streaming it.

I split the dataset with split_dataset_by_node so that no sample is repeated across GPUs. Since it’s already split, I shouldn’t need DistributedSampler, right?
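The behaviour can be reproduced without GPUs. According to the `datasets` documentation, when `n_shards` is a multiple of `world_size`, `split_dataset_by_node` assigns whole shards to each node; otherwise each node keeps one example out of every `world_size` and skips the others. The plain-Python sketch below imitates only that second case for a single-shard stream. `keep_one_of_n` and `to_batches` are hypothetical helpers, not `datasets` or PyTorch APIs. The exact values each rank receives depend on how the stream is split, but the batch counts are the point: rank 0 ends up with two batches and rank 1 with only one.

```python
from itertools import islice, zip_longest
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def keep_one_of_n(stream: Iterable[T], rank: int, world_size: int) -> Iterator[T]:
    """Hypothetical stand-in for the non-divisible case: rank r keeps every
    world_size-th example, starting at position r, and skips the rest."""
    return islice(stream, rank, None, world_size)


def to_batches(stream: Iterable[T], batch_size: int) -> List[List[T]]:
    """Group one rank's examples into batches, keeping a final partial batch."""
    it = iter(stream)
    batches = []
    while batch := list(islice(it, batch_size)):
        batches.append(batch)
    return batches


samples = [1, 2, 3, 4, 5]
world_size, batch_size = 2, 2
per_rank = [
    to_batches(keep_one_of_n(samples, r, world_size), batch_size)
    for r in range(world_size)
]
# zip_longest lines the ranks up step by step; None means "no batch this step".
for step, batches in enumerate(zip_longest(*per_rank), start=1):
    print(f"step {step}: {batches}")
# step 1: ([1, 3], [2, 4])
# step 2: ([5], None)
```

In DDP, rank 0's second backward pass would wait on a gradient all-reduce that rank 1 never joins, which typically hangs until the collective times out.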

But in this case I noticed the following:

First iteration:
first GPU gets → [1, 2]
second GPU gets → [3, 4]

Second iteration:
first GPU gets → [5]
second GPU gets → nothing

This is a problem, because with DistributedSampler the samples are repeated (padded) internally so that every GPU receives the same number of samples and no GPU is left without data at any iteration during gradient synchronization. So my questions are:

  1. Since the splitting happens beforehand, how do I make sure each GPU gets a batch at every iteration, to avoid gradient sync issues?
  2. Do we need to use DistributedSampler? If yes, how?
  3. If dataset.n_shards % world_size != 0, is it possible to shard the streaming dataset on the fly to avoid the case where data is missing?
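One way to think about questions 1 and 3 is to borrow DistributedSampler's trick. DistributedSampler relies on the dataset length: it rounds the total up to a multiple of the number of replicas, wraps around to the first indices to fill the gap, and gives each rank every num_replicas-th index. That reliance on a length is also why it does not fit an IterableDataset, and DataLoader rejects a sampler for one. The sketch below recreates that index logic in plain Python (shuffling left out) and then applies the same padding idea to a per-rank stream, assuming the total number of examples is known up front, for example from dataset metadata. `sampler_style_indices` and `pad_stream` are hypothetical helpers, not PyTorch or `datasets` APIs.

```python
import math
from itertools import islice
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def sampler_style_indices(dataset_len: int, rank: int, num_replicas: int) -> List[int]:
    """Simplified re-creation of DistributedSampler's index logic
    (shuffle=False, drop_last=False); it needs the dataset length."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    # Wrap around to the first indices until total_size is divisible by num_replicas.
    indices = (list(range(dataset_len)) * math.ceil(total_size / dataset_len))[:total_size]
    return indices[rank:total_size:num_replicas]


def pad_stream(make_stream: Callable[[], Iterable[T]], num_examples: int, world_size: int) -> Iterator[T]:
    """Hypothetical wrapper: yield exactly ceil(num_examples / world_size) items
    from this rank's stream, restarting it from the beginning to pad if it runs out."""
    target = math.ceil(num_examples / world_size)
    produced = 0
    while produced < target:
        saw_item = False
        for item in make_stream():
            saw_item = True
            yield item
            produced += 1
            if produced == target:
                return
        if not saw_item:
            raise ValueError("this rank's stream is empty; nothing to pad with")


samples = [1, 2, 3, 4, 5]
for rank in range(2):
    by_index = [samples[i] for i in sampler_style_indices(len(samples), rank, 2)]
    # Each rank's stream: every 2nd example starting at `rank` (stand-in for the split).
    padded = list(pad_stream(lambda r=rank: islice(samples, r, None, 2), num_examples=5, world_size=2))
    print(rank, by_index, padded)
# 0 [1, 3, 5] [1, 3, 5]
# 1 [2, 4, 1] [2, 4, 2]
```

With padding, both ranks hold three samples and therefore take two steps at batch size 2, at the cost of one duplicated sample per epoch, just as DistributedSampler duplicates one. Restarting the stream re-reads it from the source; buffering the first few examples would avoid that. DDP's `join()` context manager, which is designed for uneven inputs across ranks, is another route not shown here.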

I have the same question