Distributed data sampling for streaming

I'm reading data as a stream and I need to pass it to a pipeline that runs in a distributed manner, where each process is expected to handle a different batch of data.

When I tried the following:

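from datasets import load_dataset
from torch.utils.data import DataLoader

dataset = load_dataset("oscar-corpus/OSCAR-2301",
                       streaming=True,  # read the corpus as a stream
                       token=token)

# collate_fn=list keeps each batch as a plain list of examples
dataloader = iter(DataLoader(dataset, num_workers=5, batch_size=1000, collate_fn=list))

# passed to the distributed pipeline as:
#     inputs=dataloader,  # any inputs of type Iterable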

it didn't work: the DataLoader was replicated in every process, so all processes ended up with the same batches of data.

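Hi! You should be able to avoid this data duplication by using split_dataset_by_node (from datasets.distributed), as explained in IterableDataset returns duplicated data using PyTorch DDP · Issue #5360 · huggingface/datasets · GitHub. Given the rank of the current process and the world size, it returns a dataset that yields only that rank's part of the stream.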

Thank you, that solved the issue.