Creating sharded IterableDataset from a list of IterableDatasets?

Given a list of IterableDatasets, datasets (len = n), each with n_shards == 1, is it possible to create a single IterableDataset ds with ds.n_shards == n? I have a really roundabout way of doing so:

        from datasets import IterableDataset

        def gen(datasets):
            print("Creating generator")
            # Chain the datasets one after another.
            for idx, ds in enumerate(datasets):
                print(f"Starting dataset {idx}")
                for item in ds:
                    yield item
                print(f"Ending dataset {idx}")
            print("Generator exhausted")

        ds = IterableDataset.from_generator(gen, gen_kwargs={"datasets": datasets})

But this code is not ideal as it doesn’t play well with DataLoader when num_workers > 0.
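For context, datasets can only feed as many DataLoader workers as the dataset has shards, so it helps to check ds.n_shards before choosing num_workers. A minimal sketch, reusing ds from above:

        from torch.utils.data import DataLoader

        print(ds.n_shards)  # workers beyond this count receive no data
        # datasets logs a warning when num_workers exceeds ds.n_shards
        loader = DataLoader(ds, num_workers=4)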


I had the exact same problem. HF's datasets.interleave_datasets() can deal with a list of IterableDatasets, but the returned IterableDataset will have n_shards equal to the smallest n_shards in the list, which in your case is 1. So the workaround is to pre-process the list so that every IterableDataset has n_shards == n before passing it to interleave_datasets(), as in the sketch below.
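One way to do that pre-processing, as a minimal sketch (reshard is my own helper name, not a datasets API): rebuild each single-shard dataset with from_generator, passing a list of n shard ids in gen_kwargs, since datasets splits list-valued gen_kwargs across shards. Each shard re-reads the whole underlying dataset and keeps every n-th example, which is wasteful but keeps the shards disjoint:

        from datasets import IterableDataset, interleave_datasets

        def reshard(dataset, n):
            # A list-valued gen_kwarg is split across shards, so passing
            # n shard ids yields an IterableDataset with n_shards == n.
            def gen(shard_ids, dataset, n):
                for shard_id in shard_ids:
                    for idx, example in enumerate(dataset):
                        # Round-robin split: shard k keeps idx % n == k.
                        if idx % n == shard_id:
                            yield example
            return IterableDataset.from_generator(
                gen,
                gen_kwargs={"shard_ids": list(range(n)), "dataset": dataset, "n": n},
            )

        n = len(datasets)
        ds = interleave_datasets([reshard(d, n) for d in datasets])
        print(ds.n_shards)  # n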

Alternatively, you can combine the datasets at the PyTorch level with torch.utils.data.ChainDataset (there is no torch.utils.data.chain), which iterates its members one after another. Note that this gives you a plain torch IterableDataset rather than a datasets one, and it does not shard the work across DataLoader workers for you.
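A minimal sketch of that route, assuming the same list datasets as above (datasets' IterableDataset subclasses torch.utils.data.IterableDataset when torch is installed, so the chain can be fed to a DataLoader):

        from torch.utils.data import ChainDataset, DataLoader

        # ChainDataset concatenates iterable-style datasets lazily,
        # yielding from each member in turn.
        chained = ChainDataset(datasets)

        # Sharding caveat: ChainDataset itself does no worker splitting,
        # so with num_workers > 0 the parallelism still depends on how
        # each member dataset handles torch worker_info.
        loader = DataLoader(chained, num_workers=0)
        for batch in loader:
            break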