How to make Trainer train the model one epoch at a time?

Hello! This is my first question, so I hope I’m posting in the right place.

Short summary:
How do I tell an already initialized Trainer to train for a single epoch at a time?

Long version:
I need to modify the training data at every epoch during training. I have a Trainer object that is initialized with a fixed number of epochs (say 10) and a learning rate schedule.
What I would like to do looks something like this:

    for i in range(epochs):
        data = modify_data()  # rebuild the training data for this epoch
        trainer.train_dataset = data["train"]
        trainer.train()  # ...but train for only one epoch here

If I just set the num_train_epochs parameter to 1 in TrainingArguments, the learning rate scheduler decays the learning rate to 0.0 by the end of the first epoch, so every later epoch trains with a learning rate of 0 and is useless.
If I instead create a new Trainer at each iteration, I lose the state of the learning rate schedule.
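To make the problem concrete, here is a small pure-Python sketch of the learning rate under a linear decay, assuming the default `lr_scheduler_type="linear"` with no warmup and a hypothetical 100 steps per epoch. It reimplements the decay formula rather than calling the library, so treat it as an approximation of what the real scheduler does.

```python
def linear_lr(step: int, total_steps: int, base_lr: float) -> float:
    """Learning rate at `step` under linear decay to 0 (no warmup assumed)."""
    return base_lr * max(0.0, (total_steps - step) / max(1, total_steps))


steps_per_epoch = 100  # hypothetical: number of batches in one epoch
base_lr = 5e-5         # hypothetical base learning rate

# Schedule sized for one epoch (num_train_epochs=1): reaches 0 after epoch 1.
one_epoch = [linear_lr(s, steps_per_epoch, base_lr) for s in range(steps_per_epoch + 1)]

# Schedule sized for all 10 epochs: still 90% of base_lr after epoch 1.
after_first_epoch = linear_lr(steps_per_epoch, 10 * steps_per_epoch, base_lr)

print(one_epoch[-1])      # 0.0
print(after_first_epoch)  # approximately 4.5e-05
```

This is why the schedule has to be sized for the full run and kept alive across epochs rather than recreated per epoch.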

I have tried looking at the available callbacks, but I couldn't see any way to make training stop at the end of each epoch and then resume. Maybe I'm missing something? If anyone knows how to do this, I'd be really grateful.

Thanks in advance!

After some trial and error I found a solution that works for me. I'll leave it here in case anyone else is in the same situation.

I ended up defining a TrainerCallback whose on_epoch_end() sets control.should_training_stop = True.

    from transformers import TrainerCallback

    # Stop training after each epoch
    class StopCallback(TrainerCallback):
        def on_epoch_end(self, args, state, control, **kwargs):
            # Ask the Trainer to return from train() once this epoch finishes
            control.should_training_stop = True

With this, when I call trainer.train() again in the loop, it continues training while keeping the state of the learning rate schedule.
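Put together with the StopCallback, the loop from the question might look like the sketch below. It takes duck-typed `trainer` and `modify_data` arguments so the control flow is easy to see; `train_one_epoch_at_a_time` is a hypothetical helper name, and whether the optimizer and scheduler really carry over between `train()` calls may depend on the transformers version you use.

```python
def train_one_epoch_at_a_time(trainer, modify_data, epochs):
    """Run trainer.train() once per epoch, swapping in fresh training data first.

    Assumes StopCallback has been registered on `trainer`, so each train()
    call returns at the end of one epoch. `modify_data` stands for the user's
    own (hypothetical) function returning a dict with a "train" split.
    """
    for _ in range(epochs):
        data = modify_data()
        trainer.train_dataset = data["train"]  # data used for the next epoch
        trainer.train()  # returns after one epoch because of the callback


# Hypothetical usage (model, initial_train and output_dir are placeholders):
# args = TrainingArguments(output_dir="out", num_train_epochs=10)
# trainer = Trainer(model=model, args=args, train_dataset=initial_train)
# trainer.add_callback(StopCallback())
# train_one_epoch_at_a_time(trainer, modify_data, epochs=10)
```

Keeping `epochs` equal to `num_train_epochs` should let the schedule decay over the whole run instead of within the first epoch.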

The only drawback is that saving has to be handled separately; otherwise the Trainer overwrites the same save path every time (the step count is the same at the end of each epoch).
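One way to handle saving is to turn off the automatic checkpoints (assumed here: `save_strategy="no"` in `TrainingArguments`) and save explicitly after each `train()` call into a directory named after the epoch. `save_epoch` and the `epoch-N` naming are hypothetical; `Trainer.save_model` is the existing method that writes the model to a given directory.

```python
import os


def save_epoch(trainer, output_dir, epoch):
    """Save the model under output_dir/epoch-<epoch> so epochs don't overwrite each other."""
    path = os.path.join(output_dir, f"epoch-{epoch}")
    trainer.save_model(path)  # writes the model (and config) into `path`
    return path


# Hypothetical usage inside the per-epoch loop, right after trainer.train():
# save_epoch(trainer, "out", epoch)
```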