Let’s say I have a dataset with 5 samples with values [1, 2, 3, 4, 5], 2 GPUs (for DDP), and a batch size of 2. The dataset is an `IterableDataset` since I am streaming it.
Now I split the dataset using `split_dataset_by_node` to ensure it doesn’t get repeated. And since it’s already split, I shouldn’t have to use a `DistributedSampler`?
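Roughly what my setup looks like, simulated locally in one process (the `rank` loop just stands in for the two DDP processes; the toy generator and the `value` column name are only illustrative):

```python
from datasets import IterableDataset
from datasets.distributed import split_dataset_by_node
from torch.utils.data import DataLoader

# Toy streaming dataset with the 5 samples (a single shard, since it
# comes from one generator).
def gen():
    for x in [1, 2, 3, 4, 5]:
        yield {"value": x}

dataset = IterableDataset.from_generator(gen)
world_size = 2

# Simulate both DDP ranks in one process; in real training `rank` would
# come from torch.distributed and each process would build only its own split.
for rank in range(world_size):
    shard = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
    for step, batch in enumerate(DataLoader(shard, batch_size=2)):
        print(f"rank {rank}, step {step}: {batch['value']}")
```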
But in this case I noticed the following:
First iteration:
first GPU will get → [1, 2]
second GPU will get → [3, 4]
Second iteration:
first GPU will get → [5]
second GPU will get → nothing
This actually creates an issue, since with a `DistributedSampler` the samples are repeated internally to ensure none of the GPUs is missing data at any iteration for the gradient sync.
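For comparison, here is a minimal sketch of that padding behaviour (using a plain map-style list instead of my streaming dataset, with the two ranks simulated locally):

```python
from torch.utils.data import DistributedSampler

data = [1, 2, 3, 4, 5]
world_size = 2

for rank in range(world_size):
    # num_replicas/rank passed explicitly, so no process group is needed here.
    sampler = DistributedSampler(
        data, num_replicas=world_size, rank=rank, shuffle=False, drop_last=False
    )
    # With drop_last=False the index list is padded to a length divisible by
    # num_replicas by repeating samples, so every rank gets 3 elements.
    print(rank, [data[i] for i in sampler])
# rank 0 -> [1, 3, 5]
# rank 1 -> [2, 4, 1]  (sample 1 repeated as padding)
```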
So my questions are:
- Since the splitting happens beforehand, how do I make sure each GPU gets a batch at every iteration to avoid gradient sync issues?
- Do we still need to use a `DistributedSampler`? If yes, how?
- If `dataset.n_shards % world_size != 0`, is it possible to shard the streaming dataset on the fly to avoid the case where data is missing? (See the check sketched after this list.)
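For the last question, this is the check I have in mind (continuing from the first snippet); my understanding of the documented behaviour of `split_dataset_by_node` is in the comments, so please correct me if it’s wrong:

```python
# `dataset` and `world_size` as in the first snippet above.
if dataset.n_shards % world_size == 0:
    # Documented fast path: each rank is assigned n_shards // world_size
    # whole shards and streams only those.
    print(f"{dataset.n_shards} shards assigned evenly across {world_size} ranks")
else:
    # Otherwise every rank streams all the data and keeps 1 example out of
    # world_size, skipping the rest.
    print("n_shards not divisible by world_size: ranks skip examples instead")
```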