Resume_from_checkpoint & skipping batches: why does the processing function need to be run for skipped batches?

hello, i’m having trouble resuming from a checkpoint quickly, because transformers seems to run the complete data processing pipeline even for the first n batches it doesn’t use. in my case, it takes about an hour to resume training before it “skips” the first n batches.

my pipeline is pretty straightforward: load data, augment if necessary, tokenize and return the sample.

i’m hoping someone could shed some light on the rationale behind incrementing the index (even for iterable datasets). is it due to random seeds, i.e. would incrementing without running the preprocessing inside the loop lead to reproducibility issues? if so, and if i accept that, how can i skip ahead without using --ignore_data_skip (because i don’t want to start from batch 0 but from where i left off)?

We recently merged a fix for that. You will need an installation of Transformers from source and the latest version of Accelerate.
Note that this only works for non-iterable datasets.
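Why indexable (map-style) datasets can benefit: with indices, the skip can happen at the sampler level, so `__getitem__` is never called for the skipped batches. A minimal pure-Python sketch, assuming a sequential order; with shuffling, the sampler would first have to regenerate the same seeded permutation and then drop the first batches, which is where the random seed matters. The names are hypothetical, the classes only mimic the `torch.utils.data.Dataset` protocol, and this is not Accelerate's actual implementation (Accelerate exposes a `skip_first_batches` helper for this purpose).

```python
from typing import Iterator, List


class CountingDataset:
    """Map-style dataset (same __len__/__getitem__ protocol as
    torch.utils.data.Dataset) that counts how often it does real work."""

    def __init__(self, n: int):
        self.n = n
        self.calls = 0

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int) -> int:
        self.calls += 1  # stands in for loading, augmentation and tokenization
        return idx * 10


def batch_indices(n: int, batch_size: int) -> Iterator[List[int]]:
    """Yield index lists in order, like a sequential BatchSampler."""
    for start in range(0, n, batch_size):
        yield list(range(start, min(start + batch_size, n)))


def iterate_from(dataset, batch_size: int, skip_batches: int) -> Iterator[list]:
    """Drop the first `skip_batches` batches at the index level, so
    __getitem__ is never called for them."""
    for b, indices in enumerate(batch_indices(len(dataset), batch_size)):
        if b < skip_batches:
            continue  # only the indices are discarded; no sample is built
        yield [dataset[i] for i in indices]
```

For example, `list(iterate_from(CountingDataset(100), 10, 7))` yields the last 3 batches while `__getitem__` runs only 30 times instead of 100.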

thanks @sgugger, any reason why it doesn’t work for non-shuffled iterables?

We can’t index into iterable datasets; the only way to skip iterations is to run through them.
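To make that concrete: an iterable dataset only produces samples by running its `__iter__`, so skipping k samples means producing and discarding them, preprocessing included. A small sketch with hypothetical names, in pure Python mimicking the `torch.utils.data.IterableDataset` protocol; the skip uses the standard itertools "consume" recipe.

```python
import itertools
from typing import Iterator


class CountingStream:
    """Iterable-style dataset (same __iter__ protocol as
    torch.utils.data.IterableDataset): samples only exist once produced."""

    def __init__(self, n: int):
        self.n = n
        self.processed = 0

    def __iter__(self) -> Iterator[int]:
        for i in range(self.n):
            self.processed += 1  # preprocessing runs whether or not the sample is used
            yield i * 10


def skip_examples(stream, k: int) -> Iterator[int]:
    """Generic skip for an iterable: pull k samples and throw them away."""
    it = iter(stream)
    next(itertools.islice(it, k, k), None)  # itertools "consume" recipe
    return it
```

After `skip_examples(CountingStream(100), 70)`, all 70 skipped samples have been fully processed even though none of them is used.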

@sgugger A better solution is to pass a flag to the CustomDataset telling it that we are now in ‘skip’ mode, so that the CustomDataset can avoid expensive and unneeded steps during the skipping phase, such as tokenizing words. I guess that would make it about x100 faster.

@sgugger can you explain (maybe with a short example) how to use the solution that you mentioned above in the new version?

Is a regular Pandas DataFrame considered ‘non-iterable’?

Feel free to implement your solution on your dataset. The Trainer cannot know in advance that your dataset has some custom skip flag as it works with any PyTorch dataset.

For the new version, you just need to install both Transformers and Accelerate from source.

Thanks, I just implemented my idea and it works: it speeds up the skipping phase by at least x1000, which makes it workable.
The solution goes like this:

  1. Patch trainer.py by adding one line right after
         steps_trained_in_current_epoch -= 1
     namely:
         epoch_iterator.dataset.steps_trained_in_current_epoch = steps_trained_in_current_epoch
  2. In the __init__ of your class CustomDataset(Dataset), add:
         self.steps_trained_in_current_epoch = 0
  3. At the top of __getitem__ in your class CustomDataset(Dataset), add:
         if self.steps_trained_in_current_epoch > 0:
             return {'input_ids': …
     and only then do the tokenization (which takes 99% of the time during skipping).
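The three steps above can be sketched end to end as a simplified, pure-Python imitation (not the actual Trainer code): `resume_epoch` stands in for the Trainer's skip loop, `expensive_tokenize` is a hypothetical stand-in for the tokenizer, and the placeholder `{"input_ids": []}` is an assumption, since the original snippet is cut off. Unlike the three steps, the sketch also sets the counter once before the loop so the first skipped batch is cheap too. It relies on the loop mutating the same dataset object that builds the samples, which holds with `num_workers=0`; DataLoader worker processes each work on their own copy of the dataset.

```python
from typing import Dict, List


def expensive_tokenize(text: str) -> List[int]:
    """Hypothetical stand-in for a real (slow) tokenizer call."""
    return [ord(c) for c in text]


class CustomDataset:
    """Map-style dataset; in real code this would subclass torch.utils.data.Dataset."""

    def __init__(self, texts: List[str]):
        self.texts = texts
        self.steps_trained_in_current_epoch = 0  # step 2: skip counter
        self.tokenize_calls = 0  # only here to show how much work is done

    def __len__(self) -> int:
        return len(self.texts)

    def __getitem__(self, idx: int) -> Dict[str, List[int]]:
        if self.steps_trained_in_current_epoch > 0:
            # step 3: this sample's batch will be discarded, so skip tokenization
            return {"input_ids": []}  # assumed placeholder
        self.tokenize_calls += 1
        return {"input_ids": expensive_tokenize(self.texts[idx])}


def resume_epoch(dataset: CustomDataset, batch_size: int, steps_to_skip: int) -> List[list]:
    """Simplified imitation of the Trainer's skip loop, including the step-1 patch."""
    steps_trained_in_current_epoch = steps_to_skip
    # Not part of the three steps: prime the counter so the first batch is cheap too
    dataset.steps_trained_in_current_epoch = steps_trained_in_current_epoch
    trained_batches = []
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        batch = [dataset[i] for i in range(start, stop)]  # what the DataLoader would fetch
        if steps_trained_in_current_epoch > 0:
            steps_trained_in_current_epoch -= 1
            # step 1: the one-line patch keeps the dataset's counter in sync
            dataset.steps_trained_in_current_epoch = steps_trained_in_current_epoch
            continue
        trained_batches.append(batch)
    return trained_batches
```

With 20 texts, a batch size of 4 and 3 steps to skip, `resume_epoch` returns the last 2 batches and the tokenizer runs only 8 times instead of 20.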

Works great.

Regarding what you suggested: right now I do not use Accelerate. Do I have to use it for the ‘skipping’ feature to work in the new version of Transformers?

This functionality is particularly useful for large datasets, and large datasets are almost always iterables, because they are large :)

There is a solution for this in TensorFlow using tf.train.Checkpoint (TensorFlow v2.12.0), which I used while running some jobs with the t5x repo and found particularly helpful (e.g. for preemptible training runs). I will take a look at how this can be used in HF scripts in a generalized way.

import tensorflow as tf
import torch


class MyIterableDataset(torch.utils.data.IterableDataset):
    """Wraps a tf.data iterator so that its position can be checkpointed and restored."""

    def __init__(self, tf_dataset_iterator, ckpt_path="/tmp/train_ds"):
        super().__init__()
        self.tf_dataset_iterator = tf_dataset_iterator  # e.g. iter(some_tf_dataset)
        self.ckpt_path = ckpt_path
        # Track the iterator so its position is part of every checkpoint
        self.checkpoint = tf.train.Checkpoint(iterator=self.tf_dataset_iterator)

    def __iter__(self):
        for ex in self.tf_dataset_iterator:
            yield ex.numpy()

    def save(self):
        """Save the underlying tf.data.Dataset iterator state.

        See https://www.tensorflow.org/api_docs/python/tf/train/Checkpoint#args
        Returns the written checkpoint path, e.g. "/tmp/train_ds-1".
        """
        return self.checkpoint.save(self.ckpt_path)

    def load(self, index=1, ckpt_path=None):
        """Restore the underlying tf.data.Dataset iterator state."""
        if ckpt_path is None:
            # checkpoint.save() appends "-<save counter>" to the prefix
            self.checkpoint.restore(f"{self.ckpt_path}-{index}")
        else:
            self.checkpoint.restore(ckpt_path)
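The same principle, checkpointing the iterator's position instead of replaying it, does not depend on TensorFlow. As a sketch under simplifying assumptions (a single UTF-8 text file with one example per line, a single process, no DataLoader prefetching), a stream can save its byte offset and seek back to it on resume. `ResumableLineStream` and its JSON checkpoint format are hypothetical.

```python
import json
import os
import tempfile
from typing import Iterator


class ResumableLineStream:
    """Iterable over the lines of a UTF-8 text file that can checkpoint its
    byte offset and resume there without re-reading earlier lines."""

    def __init__(self, path: str, offset: int = 0):
        self.path = path
        self.offset = offset  # byte position of the next unread line

    def __iter__(self) -> Iterator[str]:
        with open(self.path, "rb") as f:
            f.seek(self.offset)  # jump straight to the resume point
            for line in iter(f.readline, b""):
                self.offset = f.tell()  # position just past this line
                yield line.decode("utf-8").rstrip("\r\n")

    def save(self, ckpt_path: str) -> None:
        """Persist the file path and current offset as JSON."""
        with open(ckpt_path, "w") as f:
            json.dump({"path": self.path, "offset": self.offset}, f)

    @classmethod
    def load(cls, ckpt_path: str) -> "ResumableLineStream":
        """Recreate a stream positioned where the saved one stopped."""
        with open(ckpt_path) as f:
            state = json.load(f)
        return cls(state["path"], state["offset"])


if __name__ == "__main__":
    tmp = tempfile.mkdtemp()
    data_path = os.path.join(tmp, "data.txt")
    with open(data_path, "w") as f:
        f.writelines(f"line{i}\n" for i in range(10))

    stream = ResumableLineStream(data_path)
    it = iter(stream)
    consumed = [next(it) for _ in range(4)]  # "train" on 4 examples
    stream.save(os.path.join(tmp, "ckpt.json"))

    resumed = ResumableLineStream.load(os.path.join(tmp, "ckpt.json"))
    print(consumed, list(resumed))  # resumed stream starts at "line4"
```

In a training script, `save` would be called together with the model checkpoint. If a DataLoader prefetches, the saved offset runs ahead of what the model has actually consumed, so the prefetch depth would have to be accounted for.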