I’m trying to unify some training code across different mixtures of HF datasets, and it would be nice to make an IterableDataset infinite in length and combine several of them with interleave_datasets (sort of like webdataset’s resample=True).
Hi! With datasets>=3.3 you can use IterableDataset.repeat(), for example:
infinite_iterable_dataset = iterable_dataset.repeat(None)  # num_times=None repeats indefinitely
This sounds great. Just to be sure, how does this interact with interleave_datasets, split_by_node and shuffle? Would splitting first, then repeating and shuffling each part, and finally interleaving ensure that each rank receives infinite but non-overlapping data?
split_by_node ensures that each rank receives different data, independently of the other operations.
So you can definitely do repeat + interleave_datasets + split_by_node + shuffle in that order.
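A stdlib-only toy model of why that order still gives each rank its own data. Here split_by_node stands for datasets.distributed.split_dataset_by_node, whose documentation says whole shards are assigned to each node when num_shards % world_size == 0 (otherwise each node keeps 1 example out of every world_size, which this toy does not model). The source names, shard layout, round-robin interleave and buffer size are all made up, and the toy applies the steps in the order they take effect for one rank rather than reproducing datasets’ lazy pipeline.

```python
import itertools
import random

# Hypothetical toy data: two sources, each split into 4 shards of example ids.
SOURCES = {
    "a": [[f"a{s}-{i}" for i in range(3)] for s in range(4)],
    "b": [[f"b{s}-{i}" for i in range(2)] for s in range(4)],
}


def rank_stream(rank, world_size, seed=0, buffer_size=4):
    """Toy model of repeat -> interleave -> split_by_node -> shuffle for one rank.

    Assumes the shard-based case (num_shards % world_size == 0): the rank keeps
    only the shards whose index % world_size == rank.
    """
    per_source = []
    for shards in SOURCES.values():
        # split_by_node: keep only this rank's shards
        mine = [ex for s, shard in enumerate(shards) if s % world_size == rank for ex in shard]
        # repeat: cycle over this rank's examples forever
        per_source.append(itertools.cycle(mine))
    # interleave: plain round-robin over the sources (no sampling probabilities)
    interleaved = (next(it) for it in itertools.cycle(per_source))
    # shuffle: fixed-size buffer; each new example replaces a randomly chosen one
    rng = random.Random(seed)
    buffer = []
    for ex in interleaved:
        if len(buffer) < buffer_size:
            buffer.append(ex)
            continue
        i = rng.randrange(buffer_size)
        yield buffer[i]
        buffer[i] = ex


# Simulate 2 ranks and collect the first 100 examples each one would see.
seen = [set(itertools.islice(rank_stream(r, world_size=2), 100)) for r in range(2)]
print(seen[0].isdisjoint(seen[1]))  # True
```

With 4 shards per source and 2 ranks, rank 0 only ever cycles through shards 0 and 2 and rank 1 through shards 1 and 3, so the streams are infinite but never share an example.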
PS: as of today there is an issue with interleave_datasets, which overwrites the split_by_node and shuffle settings of its input datasets. So, if you can, call split_by_node and shuffle after interleave_datasets; see interleave_datasets resets shuffle state · Issue #7156 · huggingface/datasets · GitHub
It’s possible to fix the issue in iterable_dataset.py in datasets, though, if you’d like to open a PR. The issue is that _interleave_iterable_datasets doesn’t pass the shuffling and distributed arguments to the resulting IterableDataset.
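To make that failure mode concrete, here is a hypothetical toy model; none of these names are datasets’ API. A dataset object carries its data plus pipeline settings, and a combine step that builds a fresh object from the data alone silently loses the settings. The "passing" variant is only a sketch of the fix direction and assumes all inputs share the same settings.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class ToyIterableDataset:
    """Hypothetical stand-in for IterableDataset: data plus pipeline settings."""
    sources: Tuple[str, ...]
    shuffling: Optional[int] = None                # toy: a shuffle seed
    distributed: Optional[Tuple[int, int]] = None  # toy: (rank, world_size)


def toy_interleave_dropping_settings(datasets):
    # Mirrors the reported bug: the result is built from the data only,
    # so shuffling/distributed settings on the inputs are lost.
    return ToyIterableDataset(sources=tuple(s for d in datasets for s in d.sources))


def toy_interleave_passing_settings(datasets):
    # Fix direction, assuming all inputs share the same settings:
    # forward them to the resulting dataset instead of dropping them.
    first = datasets[0]
    return ToyIterableDataset(
        sources=tuple(s for d in datasets for s in d.sources),
        shuffling=first.shuffling,
        distributed=first.distributed,
    )


inputs = [
    ToyIterableDataset(("a",), shuffling=42, distributed=(0, 2)),
    ToyIterableDataset(("b",), shuffling=42, distributed=(0, 2)),
]
dropped = toy_interleave_dropping_settings(inputs)
passed = toy_interleave_passing_settings(inputs)
print(dropped.shuffling, dropped.distributed)  # None None
print(passed.shuffling, passed.distributed)    # 42 (0, 2)
```

This is also why applying split_by_node and shuffle after interleave_datasets sidesteps the problem: the settings are attached to the combined dataset, so there is nothing left for the combine step to drop.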
It seems that on 3.4.1, following the order repeat + interleave_datasets + shuffle, I get this error:
NotImplementedError: <class 'datasets.iterable_dataset.RepeatExamplesIterable'> doesn't implement num_shards yet
Moving repeat to the end seems to give the same error, while interleave + shuffle on their own work fine. Is this a newly introduced issue? I can’t find it reported anywhere. Here is the full trace:
for b in tqdm.tqdm(loader):
File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/tqdm/std.py", line 1181, in __iter__
for obj in iterable:
File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 701, in __next__
data = self._next_data()
^^^^^^^^^^^^^^^^^
File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1465, in _next_data
return self._process_data(data)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1491, in _process_data
data.reraise()
File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/torch/_utils.py", line 715, in reraise
raise exception
NotImplementedError: Caught NotImplementedError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 351, in _worker_loop
data = fetcher.fetch(index) # type: ignore[possibly-undefined]
^^^^^^^^^^^^^^^^^^^^
File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 33, in fetch
data.append(next(self.dataset_iter))
^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/datasets/iterable_dataset.py", line 2252, in __iter__
yield from self._iter_pytorch()
File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/datasets/iterable_dataset.py", line 2132, in _iter_pytorch
if self._is_main_process() and ex_iterable.num_shards < worker_info.num_workers:
^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/datasets/iterable_dataset.py", line 1911, in num_shards
return self.ex_iterable.num_shards
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/datasets/iterable_dataset.py", line 1562, in num_shards
return self.ex_iterable.num_shards
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/datasets/iterable_dataset.py", line 737, in num_shards
return min(ex_iterable.num_shards for ex_iterable in self.ex_iterables)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/datasets/iterable_dataset.py", line 737, in <genexpr>
return min(ex_iterable.num_shards for ex_iterable in self.ex_iterables)
^^^^^^^^^^^^^^^^^^^^^^
File "/gpfs/data/oermannlab/users/xl3942/.conda/envs/simdino/lib/python3.11/site-packages/datasets/iterable_dataset.py", line 183, in num_shards
raise NotImplementedError(f"{type(self)} doesn't implement num_shards yet")
NotImplementedError: <class 'datasets.iterable_dataset.RepeatExamplesIterable'> doesn't implement num_shards yet
Ah, it looks like it was forgotten during the implementation -_-
num_shards should simply return the num_shards of the underlying examples iterable.
Would you like to open a PR to fix this?
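A minimal toy version of that delegation, using hypothetical Toy* classes rather than the real ones in datasets.iterable_dataset. The base class raises the same NotImplementedError message as in the trace above, and the repeat wrapper forwards num_shards to the iterable it wraps, since repeating the examples does not change how the data is sharded.

```python
class ToyExamplesIterable:
    """Hypothetical base class: num_shards is unimplemented by default."""

    @property
    def num_shards(self) -> int:
        raise NotImplementedError(f"{type(self)} doesn't implement num_shards yet")


class ToyShardedSource(ToyExamplesIterable):
    """Hypothetical leaf iterable made of a list of shards."""

    def __init__(self, shards):
        self.shards = shards

    def __iter__(self):
        for shard in self.shards:
            yield from shard

    @property
    def num_shards(self) -> int:
        return len(self.shards)


class ToyRepeatExamplesIterable(ToyExamplesIterable):
    """Hypothetical repeat wrapper; num_times=None means repeat forever."""

    def __init__(self, ex_iterable, num_times=None):
        self.ex_iterable = ex_iterable
        self.num_times = num_times

    def __iter__(self):
        repeat_index = 0
        while self.num_times is None or repeat_index < self.num_times:
            yield from self.ex_iterable
            repeat_index += 1

    @property
    def num_shards(self) -> int:
        # Repeating does not change the shard layout, so delegate.
        return self.ex_iterable.num_shards


source = ToyShardedSource([[0, 1], [2, 3], [4]])
repeated = ToyRepeatExamplesIterable(source, num_times=2)
print(repeated.num_shards)  # 3
print(list(repeated))       # [0, 1, 2, 3, 4, 0, 1, 2, 3, 4]
```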
Yeah, of course. I’ll get to this tonight or tomorrow after work.