Hello!
I’m trying to find a reasonable way to resume LLM training from a checkpoint when training on a large dataset in streaming mode. In a nutshell, what I need to do is skip to a batch given its index.
The only recommended solution that I’ve found here so far is to just use skip to go to the required batch. However, as far as I understand, that would require re-running the data pipeline for all previous batches. So if my training crashed after processing 100 GB of data, I’d need to re-download and re-process all 100 GB before I could continue training, which would be very time-consuming.
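To illustrate why skip is costly in streaming mode, here is a minimal sketch in plain Python, assuming a generator stands in for the streaming pipeline (CountingStream and skip_records are hypothetical names, not HF API). A stream can only skip records by pulling and discarding them, so every skipped record is still downloaded and preprocessed:

```python
import itertools
from typing import Iterator


class CountingStream:
    """Hypothetical stand-in for a streaming pipeline that counts how many
    records were actually produced (i.e. downloaded and preprocessed)."""

    def __init__(self, num_records: int) -> None:
        self.num_records = num_records
        self.processed = 0

    def __iter__(self) -> Iterator[int]:
        for i in range(self.num_records):
            self.processed += 1  # simulate download + preprocessing cost
            yield i


def skip_records(it: Iterator[int], n: int) -> Iterator[int]:
    # A stream has no random access: skipping means consuming and dropping n records.
    return itertools.islice(it, n, None)


s = CountingStream(1000)
resumed = skip_records(iter(s), 900)
first = next(resumed)
print(first, s.processed)  # 900 901
```

Getting the first record after the resume point required producing 901 records, the same work as the original run up to that point.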
For a sharded dataset, a more reasonable solution would be to manually iterate over shards while streaming data within each shard. In this case, when restoring from a checkpoint, I could restore just the current shard index and then run skip only within that shard, which should be much faster. However, I cannot find any way to achieve this using the HF API.
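A minimal sketch of the shard-level approach, assuming shards are read in a fixed order and each shard yields its records deterministically. ShardedStream, StreamState and open_shard are hypothetical names for illustration, not part of the HF API. The checkpoint stores only the shard index and the offset within that shard:

```python
import dataclasses
import itertools
from dataclasses import dataclass
from typing import Callable, Iterator, Optional, Sequence


@dataclass
class StreamState:
    shard_idx: int = 0  # index of the shard currently being read
    offset: int = 0     # records already handed out from that shard


class ShardedStream:
    """Hypothetical resumable reader: iterates shards in order and streams
    records within each shard, tracking (shard_idx, offset) as it goes."""

    def __init__(self, shard_names: Sequence[str],
                 open_shard: Callable[[str], Iterator],
                 state: Optional[StreamState] = None) -> None:
        self.shard_names = list(shard_names)
        self.open_shard = open_shard
        self.state = state or StreamState()

    def checkpoint(self) -> dict:
        return dataclasses.asdict(self.state)

    def __iter__(self) -> Iterator:
        start_shard, start_offset = self.state.shard_idx, self.state.offset
        for shard_idx in range(start_shard, len(self.shard_names)):
            # Only the partially consumed shard needs an in-shard skip.
            skip = start_offset if shard_idx == start_shard else 0
            self.state = StreamState(shard_idx, skip)
            records = self.open_shard(self.shard_names[shard_idx])
            for record in itertools.islice(records, skip, None):
                self.state.offset += 1  # count the record before handing it out
                yield record
        self.state = StreamState(len(self.shard_names), 0)


# Toy data: 3 shards of 5 records each; `pulled` logs every record read from storage.
shards = {f"shard-{k}": list(range(5 * k, 5 * k + 5)) for k in range(3)}
pulled = []

def open_shard(name: str) -> Iterator[int]:
    for r in shards[name]:
        pulled.append(r)
        yield r

stream = ShardedStream(list(shards), open_shard)
it = iter(stream)
seen = [next(it) for _ in range(7)]  # consume records 0..6, then "crash"
ckpt = stream.checkpoint()           # {'shard_idx': 1, 'offset': 2}

pulled.clear()
resumed = ShardedStream(list(shards), open_shard, StreamState(**ckpt))
rest = list(resumed)                 # [7, 8, ..., 14]
print(ckpt, rest, pulled)
```

On resume, shard 0 is never reopened and only records 5 and 6 of shard 1 are re-read and discarded. With HF datasets, open_shard could possibly wrap load_dataset(..., data_files=name, split="train", streaming=True) for each shard file, but that part is untested here, and any shuffling would need to be seeded so the order is reproducible after a restart.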
Is this currently possible? If so, how can I implement it or something similar?