Resume Training, but reset epochs

I want to change the training dataset after each epoch. For this, I have a loop where I create a new Trainer object for each dataset and, from the second iteration onward, call Trainer.train with resume_from_checkpoint=True. Each dataset should be trained for a single epoch. However, in this loop the training is not continued, because when resuming, the number of epochs the model has already been trained for is also re-loaded. How can I reset the number of epochs?

for i, k in enumerate(sorted(train_datasets.keys())):
    train_dataset = train_datasets[k]
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    # Resume from the latest checkpoint for every iteration after the first
    trainer.train(resume_from_checkpoint=i > 0)
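One workaround I have been considering (a sketch, not verified against all Trainer versions): a checkpoint directory contains a trainer_state.json, and the "epoch" and "global_step" fields in it are what the Trainer reads back when resuming. Rewriting them before calling train might make the resumed run start counting epochs from zero while still loading the saved weights and optimizer state. The helper name reset_checkpoint_progress is mine, not a transformers API:

```python
import json
from pathlib import Path

def reset_checkpoint_progress(checkpoint_dir):
    """Reset the progress counters stored in a Trainer checkpoint.

    Assumption: checkpoint_dir contains a trainer_state.json whose
    "epoch" and "global_step" fields the Trainer consults on resume.
    Weights and optimizer state live in separate files and are untouched.
    """
    state_path = Path(checkpoint_dir) / "trainer_state.json"
    state = json.loads(state_path.read_text())
    state["epoch"] = 0          # pretend no epochs have been completed
    state["global_step"] = 0    # global_step also drives batch skipping on resume
    state_path.write_text(json.dumps(state, indent=2))
```

One could then call reset_checkpoint_progress on the latest checkpoint directory before each trainer.train(resume_from_checkpoint=True) in the loop. Note that resetting global_step also restarts the learning-rate scheduler from its initial position, which may or may not be desirable here.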