Given a list of IterableDatasets, datasets (with len(datasets) == n), each with n_shards == 1, is it possible to create a single IterableDataset ds with ds.n_shards == n? I have a rather roundabout way of doing so:
from datasets import IterableDataset

def gen(datasets):
    print("Creating generator")
    for idx, ds in enumerate(datasets):
        print(f"Starting dataset {idx}")
        for item in ds:
            yield item
        print(f"Ending dataset {idx}")
    print("Generator exhausted")

ds = IterableDataset.from_generator(gen, gen_kwargs={"datasets": datasets})
But this code is not ideal as it doesn’t play well with DataLoader when num_workers > 0.
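For context, here is how the problem shows up (a minimal sketch; the num_workers value is arbitrary):

```python
from torch.utils.data import DataLoader

loader = DataLoader(ds, num_workers=4)
# Since ds.n_shards == 1, datasets typically warns that there are more
# dataloader workers than shards and stops the extra ones, so the whole
# stream is consumed by a single worker with no parallelism.
for batch in loader:
    ...
```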
I had the exact same problem. HF's datasets.interleave_datasets() accepts a list of IterableDatasets, but the returned IterableDataset has n_shards equal to the smallest n_shards in the list, which in your case is 1. So the workaround is to pre-process the list so that every IterableDataset in it has n_shards == n before passing it to interleave_datasets(), as in the sketch below.
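To illustrate the mechanism this relies on (a minimal sketch, not the only way to do the pre-processing): IterableDataset.from_generator treats a list value in gen_kwargs as a list of shards, so the resulting dataset reports n_shards equal to the length of that list, and DataLoader workers each receive a subset of it. If you can rebuild each of the n datasets from its raw source (here, hypothetical JSONL paths), you can get an n-shard dataset directly:

```python
import json
from datasets import IterableDataset

def gen(shards):
    # `shards` is this worker's subset of the list passed via gen_kwargs.
    for path in shards:
        with open(path, encoding="utf-8") as f:
            for line in f:
                yield json.loads(line)

# Hypothetical: one JSONL file per original dataset.
jsonl_paths = [f"data/part-{i}.jsonl" for i in range(8)]

ds = IterableDataset.from_generator(gen, gen_kwargs={"shards": jsonl_paths})
assert ds.n_shards == len(jsonl_paths)
```

Reading from raw sources inside the generator also avoids iterating one HF IterableDataset inside another dataset's generator, which may be where the num_workers > 0 trouble in the original snippet comes from.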
On the PyTorch side there is no torch.utils.data.chain, but torch.utils.data.ChainDataset chains multiple iterable-style datasets into one. Note the caveats, though: the result is a plain PyTorch IterableDataset with no n_shards at all, and with num_workers > 0 each worker iterates the whole chain unless the underlying datasets split work per worker themselves, so on its own it does not solve the sharding problem.
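For completeness, a minimal sketch of ChainDataset (names are illustrative):

```python
from torch.utils.data import ChainDataset, DataLoader

# Iterates the given iterable-style datasets one after another, in order.
chained = ChainDataset(datasets)

# Safe with a single worker; with num_workers > 0 each worker would run the
# full chain unless the chained datasets consult get_worker_info() to divide
# the work among themselves.
loader = DataLoader(chained, num_workers=0)
for batch in loader:
    ...
```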