We can use `skip_first_batches()` to restore the state of the dataloader (as mentioned here). But what if we use `shuffle=True` and want to restore the dataloader's position after several epochs, which involves multiple reshufflings? How can we ensure the same data order after calling `accelerator.load_state()` and then skip the correct number of batches?
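For context, here is roughly the resume pattern I'm starting from, which works fine without shuffling. It's only a sketch: the `"ckpt"` directory name and the `batches_done` counter are my own placeholders that I track myself, not something Accelerate records for me.

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataset = torch.utils.data.TensorDataset(torch.randn(100, 10),
                                         torch.randint(0, 2, (100,)))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=False)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

# ... training is interrupted partway through an epoch ...
batches_done = 7                # I record this myself alongside the checkpoint
accelerator.save_state("ckpt")  # "ckpt" is just a placeholder directory name

# --- later, on resume ---
accelerator.load_state("ckpt")  # restores model / optimizer / RNG states
skipped = accelerator.skip_first_batches(dataloader, batches_done)
for batch in skipped:
    ...                         # finish the interrupted epoch, then iterate normally
```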
For example:
- If we use `shuffle=False` in the dataloader, the data is ordered the same way in every epoch:
  - 1st epoch: 1 2 3 4 5
  - 2nd epoch: 1 2 3 4 5

  So, in this case, we don't need to worry about shuffling and can simply remember how many batches were processed in the current epoch.
- However, if we use `shuffle=True`, the order of the data changes with each epoch:
  - 1st epoch: 4 2 3 1 5
  - 2nd epoch: 5 1 2 3 4
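A tiny repro of this behaviour (the exact orders are just illustrative; they depend on the RNG state at the time each epoch starts):

```python
import torch

# With shuffle=True, every pass over the DataLoader draws a fresh permutation,
# so skipping N batches only lands on the same samples if the RNG is in the
# same state as in the original run.
dataset = torch.arange(1, 6)    # items 1..5, as in the example above
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)

for epoch in range(2):
    order = [int(b) for b in loader]
    print(f"epoch {epoch + 1}: {order}")   # e.g. [4, 2, 3, 1, 5] then [5, 1, 2, 3, 4]
```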
Suppose we stopped halfway through the second epoch. How can we restore this exact state using `accelerator.load_state()` and `accelerator.skip_first_batches()`?
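Here is a sketch of the resume logic I have in mind. I'm assuming that re-seeding the DataLoader's `generator` at the start of each epoch makes the shuffle order a pure function of `(base_seed, epoch)`, so on resume I can re-seed for the interrupted epoch and then skip the batches already seen. The seeding scheme, the `resume_epoch` / `batches_done` bookkeeping, and the `"ckpt"` directory are all my own inventions, and I'm not sure whether `accelerator.prepare()` replaces or synchronizes this generator. Is this the right approach, or is there a built-in way?

```python
import torch
from accelerate import Accelerator

accelerator = Accelerator()
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
dataset = torch.utils.data.TensorDataset(torch.randn(100, 10),
                                         torch.randint(0, 2, (100,)))

base_seed = 42
generator = torch.Generator()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=5, shuffle=True,
                                         generator=generator)
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

# Values I would save myself next to the checkpoint (Accelerate's save_state
# does not know about my epoch/batch counters).
resume_epoch, batches_done = 1, 2   # stopped 2 batches into the 2nd epoch (index 1)
accelerator.load_state("ckpt")      # assumes a checkpoint saved earlier with save_state("ckpt")

num_epochs = 5
for epoch in range(resume_epoch, num_epochs):
    generator.manual_seed(base_seed + epoch)    # try to reproduce this epoch's shuffle
    if epoch == resume_epoch:
        epoch_loader = accelerator.skip_first_batches(dataloader, batches_done)
    else:
        epoch_loader = dataloader
    for batch in epoch_loader:
        ...                                     # training step
```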