Making an infinite IterableDataset

I’m trying to unify some training code across different mixtures of hf datasets, and it would be nice to give an IterableDataset infinite length and combine several of them with interleave_datasets (sort of like webdataset’s resample=True).

1 Like

Hi! With datasets>=3.3 you can use IterableDataset.repeat(), for example:

infinite_iterable_dataset = iterable_dataset.repeat()
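Here iterable_dataset can be any dataset loaded in streaming mode, e.g. (the dataset name below is just a placeholder):

from datasets import load_dataset

# streaming=True returns an IterableDataset instead of downloading everything up front
iterable_dataset = load_dataset("my_org/my_dataset", split="train", streaming=True)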
1 Like

This sounds great. Just to be sure, how does this interact with interleave_datasets, split_by_node, and shuffle? I’m thinking of splitting first, repeating each part, shuffling each part, and then interleaving, so that each rank receives infinite but non-overlapping data?

1 Like

split_by_node ensures that each rank receives different data, independently of the other operations :slight_smile:

So you can definitely do repeat + interleave_datasets + split_by_node + shuffle in that order.

PS: as of today there is an issue with interleave_datasets which overwrites split_by_node and shuffle… So if you can, you should call split_by_node and shuffle after interleave_datasets, see interleave_datasets resets shuffle state · Issue #7156 · huggingface/datasets · GitHub
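Here is a minimal sketch of that order, assuming two placeholder streaming datasets and placeholder rank/world_size values (the dataset names, probabilities, and seeds are not from your setup):

from datasets import load_dataset, interleave_datasets
from datasets.distributed import split_dataset_by_node

# Placeholders for the distributed setup of the current process
rank, world_size = 0, 2

# Two placeholder streaming datasets, each repeated indefinitely
ds_a = load_dataset("org/dataset_a", split="train", streaming=True).repeat()
ds_b = load_dataset("org/dataset_b", split="train", streaming=True).repeat()

# Interleave first, then split per rank and shuffle last,
# to work around the issue above where interleave_datasets
# resets the split_by_node and shuffle state
mixed = interleave_datasets([ds_a, ds_b], probabilities=[0.5, 0.5], seed=42)
mixed = split_dataset_by_node(mixed, rank=rank, world_size=world_size)
mixed = mixed.shuffle(seed=42, buffer_size=1000)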

It’s possible to fix the issue in iterable_dataset.py in datasets though, if you’d like to open a PR. The issue is that _interleave_iterable_datasets doesn’t pass the shuffling and distributed arguments to the resulting IterableDataset.

1 Like

It seems that on 3.4.1, following the order of repeat + interleave_datasets + shuffle, I run into this issue:

NotImplementedError: <class 'datasets.iterable_dataset.RepeatExamplesIterable'> doesn't implement num_shards yet

Moving repeat to the end results in the same error, and interleave + shuffle on its own works fine. :face_with_raised_eyebrow: Is this a newly introduced issue? I can’t seem to find it reported anywhere. I’ll include the full trace here, along with a minimal sketch of the pipeline that triggers it below:

    for b in tqdm.tqdm(loader):
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/tqdm/std.py", line 1181, in __iter__
    for obj in iterable:
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 701, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1465, in _next_data
    return self._process_data(data)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1491, in _process_data
    data.reraise()
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/torch/_utils.py", line 715, in reraise
    raise exception
NotImplementedError: Caught NotImplementedError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 351, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 33, in fetch
    data.append(next(self.dataset_iter))
                ^^^^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/datasets/iterable_dataset.py", line 2252, in __iter__
    yield from self._iter_pytorch()
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/datasets/iterable_dataset.py", line 2132, in _iter_pytorch
    if self._is_main_process() and ex_iterable.num_shards < worker_info.num_workers:
                                   ^^^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/datasets/iterable_dataset.py", line 1911, in num_shards
    return self.ex_iterable.num_shards
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/datasets/iterable_dataset.py", line 1562, in num_shards
    return self.ex_iterable.num_shards
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/datasets/iterable_dataset.py", line 737, in num_shards
    return min(ex_iterable.num_shards for ex_iterable in self.ex_iterables)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/datasets/iterable_dataset.py", line 737, in <genexpr>
    return min(ex_iterable.num_shards for ex_iterable in self.ex_iterables)
               ^^^^^^^^^^^^^^^^^^^^^^
  File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/datasets/iterable_dataset.py", line 183, in num_shards
    raise NotImplementedError(f"{type(self)} doesn't implement num_shards yet")
NotImplementedError: <class 'datasets.iterable_dataset.RepeatExamplesIterable'> doesn't implement num_shards yet
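For reference, a minimal sketch of the kind of pipeline that triggers it for me (dataset names are placeholders; the error shows up once the dataset is iterated through a torch DataLoader with worker processes):

from datasets import load_dataset, interleave_datasets

# Placeholder streaming datasets, repeated to make them infinite
ds1 = load_dataset("org/dataset_1", split="train", streaming=True).repeat()
ds2 = load_dataset("org/dataset_2", split="train", streaming=True).repeat()

# repeat + interleave_datasets + shuffle, as described above
mixed = interleave_datasets([ds1, ds2]).shuffle(seed=42, buffer_size=1000)
# Wrapping `mixed` in a DataLoader with num_workers > 0 and iterating it
# raises the NotImplementedError above in the worker process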
1 Like

Ah it looks like it was forgotten during the implementation -_-
The num_shards should be the same as the underlying examples iterable.
Would you like to open a PR to fix this?
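In other words, the fix is just to forward num_shards to the wrapped iterable, e.g. (a rough sketch written as a monkeypatch so it can be tried directly; the ex_iterable attribute name is an assumption based on the other wrapper iterables in the traceback):

from datasets.iterable_dataset import RepeatExamplesIterable

# Forward num_shards to the underlying examples iterable, mirroring the
# other wrapper iterables; `ex_iterable` is assumed to hold that iterable
RepeatExamplesIterable.num_shards = property(lambda self: self.ex_iterable.num_shards)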

1 Like

Yeah of course. I’ll get to this tonight or tomorrow after work

2 Likes