How to ensure the dataset is shuffled for each epoch using Trainer and Datasets?

I am using the Seq2SeqTrainer and pass a datasets.arrow_dataset.Dataset as train_dataset when instantiating it. Is the dataset shuffled per epoch by default? If not, how can I make it shuffle?

My setup follows the official example: transformers/run_seq2seq.py at master · huggingface/transformers · GitHub

Thanks!

Still needs help…

The Seq2SeqTrainer (as well as the standard Trainer) uses a PyTorch Sampler to shuffle the dataset. At each epoch, it reshuffles the dataset, and it can also group samples of roughly the same length. You can find the Sampler definition here.
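
To make the idea of per-epoch reshuffling concrete, here is a minimal pure-Python sketch of a sampler that draws a fresh permutation every time it is iterated. It is not the Trainer's actual sampler; `EpochShuffleSampler` and its seed handling are hypothetical and only illustrate the behaviour described above.

```python
import random


class EpochShuffleSampler:
    """Hypothetical sampler: yields dataset indices in a new random order each epoch."""

    def __init__(self, num_samples, seed=42):
        self.num_samples = num_samples
        # One generator shared across epochs, so every pass draws a new permutation.
        self.rng = random.Random(seed)

    def __iter__(self):
        indices = list(range(self.num_samples))
        self.rng.shuffle(indices)
        return iter(indices)


sampler = EpochShuffleSampler(num_samples=8)
# Iterating the sampler once corresponds to one epoch.
epoch_orders = [list(sampler) for _ in range(3)]
for epoch, order in enumerate(epoch_orders, start=1):
    print(f"epoch {epoch}: {order}")
```

Every epoch visits each index exactly once, only the order changes, and fixing the seed makes the whole sequence of orders reproducible.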

Hi, is there a parameter that controls whether the data gets reshuffled before each epoch? And one that controls whether it is grouped by length? Thanks!

Additionally, if training is aborted and I restart from a checkpoint, does the checkpoint contain the shuffling order for the current epoch and which datapoints have not yet been seen in that epoch? Thanks!

No, there is no parameter to turn off reshuffling. That would be very bad practice, so we don’t offer that option.

Grouping by length is controlled by the group_by_length argument.

Yes, training will resume with the same shuffle, at the point you had reached at the time of the save.
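
As a rough illustration of how resuming "with the same shuffle" can work in principle: if each epoch's permutation is derived from a fixed seed plus the epoch number, a restarted run can rebuild the same order and skip the batches that were already consumed. This is a sketch under that assumption, not the Trainer's actual code; `epoch_order`, `batches`, and `remaining_batches` are hypothetical helpers.

```python
import random


def epoch_order(num_samples, seed, epoch):
    """Rebuild the permutation for a given epoch from (seed, epoch) alone."""
    rng = random.Random(seed * 1_000_003 + epoch)
    indices = list(range(num_samples))
    rng.shuffle(indices)
    return indices


def batches(indices, batch_size):
    """Split an index list into consecutive batches."""
    return [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]


def remaining_batches(num_samples, seed, epoch, batch_size, steps_done_in_epoch):
    """Batches still to be seen in `epoch` after `steps_done_in_epoch` steps."""
    all_batches = batches(epoch_order(num_samples, seed, epoch), batch_size)
    return all_batches[steps_done_in_epoch:]


# Uninterrupted run: epoch 1, 10 samples, batch size 2 -> 5 batches.
full = batches(epoch_order(10, seed=42, epoch=1), batch_size=2)

# Interrupted after 2 steps, then resumed from a checkpoint that stored
# only the seed, the epoch, and the number of completed steps.
resumed = remaining_batches(10, seed=42, epoch=1, batch_size=2, steps_done_in_epoch=2)

print(resumed == full[2:])  # prints True
```

The point of the design is that the checkpoint does not need to store the permutation itself, only enough state to regenerate it.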

thank you!

Hi Sgugger, why is it a bad practice to reshuffle the dataset at every epoch?

I thought reshuffling the dataset at every epoch reduces overfitting and improves the generalization performance of the model. By shuffling the dataset, we ensure that the model is exposed to a different sequence of samples in each epoch, which can help prevent it from memorizing the order of the training data and overfitting to specific patterns.

Shuffling the dataset also helps to improve the diversity of the mini-batches during training, which can improve the robustness of the model and make it more resistant to outliers or noise in the data.

Yes, that’s why we don’t offer the option to not reshuffle the data at each epoch for the training set.

Be careful with the word “NOT”.

What are you referring to by “that”, @sgugger? Thanks in advance.

I assume that, at the very least, shuffling once at the beginning of training is the recommended approach? E.g., just once with something like this:

    import glob
    import os
    from datasets import load_dataset

    # train_path, tokenizer, max_length and the two helper functions are defined elsewhere.
    jsonl_files = glob.glob(os.path.expanduser(train_path), recursive=True)
    train_dataset = load_dataset('json', data_files=jsonl_files, split='train')
    train_dataset = train_dataset.select(range(200))  # keep only the first 200 examples
    print(f'->{len(train_dataset)=}')
    _turns_str_2_desired_str = lambda examples: turns_str_2_desired_str(examples, tokenizer, 'turns')
    train_dataset = raw_ds_2_lm_ds_mask_eos_pad_toks(train_dataset, tokenizer, max_length, raw_str_2_desired_str=_turns_str_2_desired_str)
    train_dataset = train_dataset.shuffle(seed=42)  # one-time shuffle with a fixed seed

@sgugger Do you agree?

This is a good comment.

But why? I thought this would print out the dataset for each epoch, but it prints out the same dataset every time (the shuffled dataset for the first epoch). What did I do wrong?

    from transformers import TrainerCallback

    # tokenizer and trainer are defined elsewhere.
    class LogFirstSamplesCallback(TrainerCallback):
        def on_epoch_begin(self, args, state, control, **kwargs):
            train_dataloader = kwargs.get("train_dataloader")
            dataset = train_dataloader.dataset

            print(f"\n🌀 Epoch {int(state.epoch)+1} starts – showing first 2 samples:")

            # Index the dataset attached to the dataloader by position.
            for i in range(2):
                sample = dataset[i]
                text = tokenizer.decode(sample['input_ids'], skip_special_tokens=True)
                print(f"\nSample {i+1}:\n{text}\n")

    trainer.add_callback(LogFirstSamplesCallback())
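
One possible explanation, offered as a sketch rather than a confirmed diagnosis: indexing the dataset by position reads the underlying dataset directly, while shuffling done by a sampler only changes the order of the indices handed to the data loader, not the stored order of the dataset. The pure-Python toy below uses a hypothetical `shuffled_indices` helper as a stand-in for the sampler. If that is what happens here, inspecting the batches the data loader yields, rather than `dataset[i]`, would show the per-epoch order.

```python
import random

dataset = ["sample A", "sample B", "sample C", "sample D", "sample E"]


def shuffled_indices(num_samples, rng):
    """Stand-in for a sampler: returns a new index order, leaves the data untouched."""
    indices = list(range(num_samples))
    rng.shuffle(indices)
    return indices


rng = random.Random(0)
first_by_position = []
for epoch in range(3):
    order = shuffled_indices(len(dataset), rng)
    # What the callback does: index the dataset by position.
    first_by_position.append([dataset[i] for i in range(2)])
    # What the training loop sees: items in sampler order.
    first_by_sampler = [dataset[i] for i in order[:2]]
    print(f"epoch {epoch + 1}: dataset[0:2]={first_by_position[-1]} "
          f"sampler order={first_by_sampler}")
```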

At the same time, I noticed that when I resume training with my previously trained model and start a new round of training, the results are actually better than setting a higher number of epochs or applying various warmup strategies.

I’m not sure if this is because the new training starts with a larger learning rate, allowing it to find a lower loss surface again, or if it’s because each time training is restarted, the data is reshuffled.
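
To illustrate the first hypothesis, here is a small sketch of a linear schedule with optional warmup. The schedule, the peak learning rate, and the step counts are assumptions made for illustration; the actual scheduler and hyperparameters of this run are not stated. `linear_lr` is a hypothetical helper, not a library function.

```python
def linear_lr(step, total_steps, peak_lr=5e-5, warmup_steps=0):
    """Linear warmup to peak_lr, then linear decay to zero at total_steps."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return peak_lr * remaining / max(1, total_steps - warmup_steps)


total_steps = 1000
# Near the end of the first run the learning rate has almost decayed to zero.
lr_end_of_first_run = linear_lr(step=990, total_steps=total_steps)
# A freshly started run begins its schedule again near the peak value.
lr_start_of_new_run = linear_lr(step=10, total_steps=total_steps)

print(f"late in run 1:  {lr_end_of_first_run:.2e}")
print(f"early in run 2: {lr_start_of_new_run:.2e}")
```

Under these assumptions the restarted run trains with a learning rate roughly a hundred times larger than the tail of the previous run, which is one way the two explanations could be told apart: restart with the same data order and see whether the improvement persists.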
