Restoring from a checkpoint when training on a large dataset with streaming

Hello!

I’m trying to find a reasonable way to resume LLM training on a large dataset in streaming mode. In a nutshell, I need to skip to a batch given its index.

The only recommended solution I’ve found here so far is to use the skip method to jump to the required batch. However, as far as I understand, that would require re-running the data pipeline for all previous batches. So if my training crashed after processing 100 GB of data, I’d need to re-download and re-process all 100 GB before I could continue training, which would be very time-consuming.
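To make the cost concrete, here is a minimal, library-free sketch (not the HF datasets implementation) in which a hypothetical preprocess step counts how many records it handles. Skipping with itertools.islice, like skipping on any plain iterator, still pulls every skipped record through the pipeline before the first useful one comes out.

```python
import itertools

processed = 0  # how many records went through the (simulated) pipeline


def preprocess(record: int) -> int:
    """Hypothetical stand-in for downloading + tokenizing one record."""
    global processed
    processed += 1
    return record * 2


def stream(n_records: int):
    # Lazily yields preprocessed records, like a streaming dataset would.
    for record in range(n_records):
        yield preprocess(record)


# Resume at record 900: islice must consume (and therefore preprocess)
# the first 900 records before it can yield anything.
resumed = itertools.islice(stream(1000), 900, None)
first = next(resumed)
print(first, processed)  # 1800 901
```

The work done before training can continue grows linearly with the number of records already seen, which is exactly the problem with skipping over a single long stream.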

For a sharded dataset, a more reasonable solution would be to iterate over the shards manually while streaming data within each shard. When restoring from a checkpoint, I could then restore the current shard index and run skip only within that shard, which should be much faster. However, I can’t find any way to achieve this with the HF datasets API.
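The shard-level approach could look roughly like the sketch below. It is not an HF datasets API: ShardedStream, load_shard, locate, and the state_dict/load_state_dict method names are all hypothetical, and shards are simulated as in-memory ranges. The checkpoint stores only (shard_idx, offset), so resuming re-processes at most one shard’s worth of records instead of everything seen so far. If per-shard record counts are known, a global index such as batch_idx * batch_size can also be mapped to a shard position with bisect.

```python
import bisect
import itertools
from typing import Callable, Iterator, List, Tuple


class ShardedStream:
    """Hypothetical resumable stream: shards are read in order, records are
    streamed within each shard, and progress is tracked as (shard, offset)."""

    def __init__(self, shards: List[str], load_shard: Callable[[str], Iterator[int]]):
        self.shards = shards
        self.load_shard = load_shard  # lazily opens ONE shard
        self.shard_idx = 0
        self.offset = 0  # records already consumed from the current shard

    def state_dict(self) -> dict:
        return {"shard_idx": self.shard_idx, "offset": self.offset}

    def load_state_dict(self, state: dict) -> None:
        self.shard_idx = state["shard_idx"]
        self.offset = state["offset"]

    def __iter__(self) -> Iterator[int]:
        while self.shard_idx < len(self.shards):
            records = self.load_shard(self.shards[self.shard_idx])
            # Skip only inside the current shard; earlier shards are never opened.
            for record in itertools.islice(records, self.offset, None):
                self.offset += 1  # count the record before handing it out
                yield record
            self.shard_idx += 1
            self.offset = 0


def locate(global_index: int, shard_sizes: List[int]) -> Tuple[int, int]:
    """Map a global record index to (shard_idx, offset), given known shard sizes."""
    ends = list(itertools.accumulate(shard_sizes))  # cumulative end of each shard
    shard_idx = bisect.bisect_right(ends, global_index)
    start = ends[shard_idx - 1] if shard_idx > 0 else 0
    return shard_idx, global_index - start


# --- Usage with 5 simulated shards of 100 records each ---
opened: List[str] = []


def load_shard(name: str) -> Iterator[int]:
    opened.append(name)  # record which shards actually get opened
    base = int(name.split("-")[1]) * 100
    return iter(range(base, base + 100))


shards = [f"shard-{i}" for i in range(5)]
stream = ShardedStream(shards, load_shard)
it = iter(stream)
seen = [next(it) for _ in range(250)]
ckpt = stream.state_dict()
print(ckpt)  # {'shard_idx': 2, 'offset': 50}

# "Crash" and resume: only shard-2 is reopened, and 50 records are skipped in it.
opened.clear()
resumed_stream = ShardedStream(shards, load_shard)
resumed_stream.load_state_dict(ckpt)
first_after_resume = next(iter(resumed_stream))
print(first_after_resume, opened)  # 250 ['shard-2']

# Batch index -> shard position, e.g. batch 5 with batch_size 50:
print(locate(5 * 50, [100] * 5))  # (2, 50)
```

The within-shard skip still pays for re-reading up to one shard, so the resume cost is bounded by shard size rather than by total progress; smaller shards make resuming cheaper.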

Is this currently possible? If so, how can I implement it, or something similar?

A similar question has been asked and answered here: Offer an alternative to Iterable Dataset that allows lazy loading and processing while skipping batches efficiently · Issue #5905 · huggingface/datasets · GitHub