Given a list of IterableDatasets, datasets (len = n), each with n_shards == 1, is it possible to create a single IterableDataset ds with ds.n_shards == n? I have a really roundabout way of doing so:
from datasets import IterableDataset

def gen(datasets):
    print("Creating generator")
    for idx, ds in enumerate(datasets):
        print(f"Starting dataset {idx}: {ds.name}")
        # Yield every example of the current dataset before moving on to the next
        for item in ds:
            yield item
        print(f"Ending dataset {idx}: {ds.name}")
    print("Generator exhausted")

# Chain all datasets behind a single generator
ds = IterableDataset.from_generator(gen, gen_kwargs={"datasets": datasets})
But this code is not ideal as it doesn’t play well with DataLoader when num_workers > 0.
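
To make the goal concrete, here is a minimal, library-free sketch of what having n shards would allow: each DataLoader worker reads its own disjoint subset of the source datasets, rather than all data flowing through one generator. The helper names (shards_for_worker, iterate_worker), the toy sources, and the round-robin assignment are assumptions for illustration only, not how the datasets library is implemented.

```python
from typing import Iterable, Iterator, List, Sequence


def shards_for_worker(num_shards: int, worker_id: int, num_workers: int) -> List[int]:
    """Return the shard indices one worker would read (round-robin assignment)."""
    return list(range(worker_id, num_shards, num_workers))


def iterate_worker(sources: Sequence[Iterable], worker_id: int, num_workers: int) -> Iterator:
    """Yield items only from the sources assigned to this worker."""
    for shard_idx in shards_for_worker(len(sources), worker_id, num_workers):
        yield from sources[shard_idx]


# Three toy sources standing in for the n single-shard datasets (hypothetical data).
sources = [["a0", "a1"], ["b0", "b1"], ["c0", "c1"]]

# Simulate two workers: together they cover every item exactly once.
per_worker = [list(iterate_worker(sources, w, 2)) for w in range(2)]
print(per_worker)
```

With one shard per source dataset, the work can be split across workers without any item being read twice; with a single shard there is nothing to split.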