Restoring the state of the DataLoader using skip_first_batches() after the first epoch

We can use skip_first_batches() to restore the state of the dataloader (as mentioned here).

But what if we use shuffle=True and want to restore the dataloader’s position after several epochs, which would involve multiple reshufflings? How can we ensure the same data order after calling accelerator.load_state() and then skip the correct number of batches?

For example:

  1. If we use shuffle=False in the dataloader, the data is ordered the same way in every epoch:
  • 1st epoch: 1 2 3 4 5
  • 2nd epoch: 1 2 3 4 5
    So, in this case, we don’t need to worry about shuffling and can simply remember how many batches were processed in the current epoch.
  2. However, if we use shuffle=True, the order of the data changes with each epoch:
  • 1st epoch: 4 2 3 1 5
  • 2nd epoch: 5 1 2 3 4
    Suppose we stopped halfway through the second epoch. How can we restore this exact state using accelerator.load_state() and accelerator.skip_first_batches()?
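
To make the second case concrete, here is a minimal pure-Python sketch of what I mean. It does not use accelerate or torch and is not a claim about how either library shuffles internally. It assumes the order of each epoch is derived from a base seed plus the epoch index; the helpers `epoch_order`, `batches` and `resume` are hypothetical names made up for this illustration. Under that assumption, restoring the position needs two numbers: the epoch index and the number of batches already consumed in that epoch.

```python
import random


def epoch_order(data, seed, epoch):
    # Hypothetical helper: derive this epoch's shuffle from (seed, epoch),
    # so the same epoch always yields the same order.
    rng = random.Random(seed + epoch)
    order = list(data)
    rng.shuffle(order)
    return order


def batches(order, batch_size):
    # Split an ordered list of samples into consecutive batches.
    return [order[i:i + batch_size] for i in range(0, len(order), batch_size)]


def resume(data, seed, epoch, batches_done, batch_size):
    # Rebuild the interrupted epoch's order, then drop the batches
    # that were already processed before the interruption.
    return batches(epoch_order(data, seed, epoch), batch_size)[batches_done:]


data = [1, 2, 3, 4, 5]

# The full second epoch (epoch index 1) as it would have run uninterrupted.
full_second_epoch = batches(epoch_order(data, seed=0, epoch=1), batch_size=1)

# Resuming after 2 batches of the second epoch were already processed.
remaining = resume(data, seed=0, epoch=1, batches_done=2, batch_size=1)
```

Here `remaining` equals `full_second_epoch[2:]`. What I am unsure about is whether accelerator.load_state() restores the equivalent of `seed` and `epoch` for the real dataloader, so that skipping batches afterwards lands on the same samples.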