How to access order of shards in streaming IterableDataset

I’m training a model on the Pile, using streaming. As far as I can tell, the Pile is split into 30 shards (different files), and when I create a dataset and shuffle it, the order of the shards is also shuffled. But I can’t tell how to access that new order, nor how to set my own order.

This is annoying because when my training run is interrupted (e.g. it crashes) and I want to resume it from a checkpoint, I want to ensure the model does not see data it has already trained on, but I don’t know how to ensure this. My best guess is to reshuffle the data and hope I get lucky.

To restart from where you were, you can use the same shuffling seed and call .skip(n_examples), where n_examples is the number of examples already consumed. It may be slow, though, because it iterates over all the previous examples before reaching the right one.
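As a rough sketch of that resume step: the helper below is hypothetical (resume_stream is not part of the library) and assumes ds is a streaming IterableDataset, and that you recorded the seed, buffer size and number of consumed examples in your checkpoint.

```python
def resume_stream(ds, seed, n_seen, buffer_size=1000):
    """Rebuild the shuffled stream and drop the examples already consumed.

    Hypothetical helper. `ds` is assumed to be a streaming
    `datasets.IterableDataset`; `seed` and `buffer_size` must match the
    values used in the interrupted run, otherwise the order will differ.
    """
    shuffled = ds.shuffle(seed=seed, buffer_size=buffer_size)
    # skip() still has to iterate over the first n_seen examples,
    # so resuming late in a run can take a while.
    return shuffled.skip(n_seen)
```

You would call it as resume_stream(ds, seed=42, n_seen=steps_done * batch_size), with those values read back from your own checkpoint metadata.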

Anyway, there’s no public API to know the order of the shards right now, but it’s possible using some internal functions (which could change in the future):

from copy import deepcopy
from datasets.iterable_dataset import _shuffle_kwargs  # internal helper, may change

ds = ds.shuffle(...)
ex_iterable = ds._ex_iterable
# Follow ex_iterable.ex_iterable or ex_iterable.ex_iterables recursively until
# you reach a ShardShuffledExamplesIterable object; call it shard_shuffled_ex_iterable.
# Copy the generator so the dataset's own RNG state is not advanced.
rng = deepcopy(shard_shuffled_ex_iterable.generator)
shuffled_kwargs = _shuffle_kwargs(rng, shard_shuffled_ex_iterable.kwargs)
# one of the kwargs contains the list of shuffled shards
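The "access recursively" step could be sketched roughly like this. The helper name find_ex_iterable is made up for illustration, and it assumes only the attribute names mentioned above (ex_iterable, ex_iterables) and the class name ShardShuffledExamplesIterable, all of which are internal and may differ in your installed version.

```python
def find_ex_iterable(root, class_name="ShardShuffledExamplesIterable"):
    """Search nested ex_iterable / ex_iterables attributes for a class name.

    Hypothetical helper: relies on internal attribute names that are not
    part of any public API. Returns None if no matching object is found.
    """
    stack = [root]
    while stack:
        node = stack.pop()
        if type(node).__name__ == class_name:
            return node
        # Wrappers around a single iterable expose it as .ex_iterable
        child = getattr(node, "ex_iterable", None)
        if child is not None:
            stack.append(child)
        # Wrappers around several iterables expose them as .ex_iterables
        stack.extend(getattr(node, "ex_iterables", None) or [])
    return None
```

With that, shard_shuffled_ex_iterable = find_ex_iterable(ds._ex_iterable) gives you the object used in the snippet above, or None if the class name does not match your version.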