I want to change the training dataset after each epoch. To do this, I loop over the datasets, create a new Trainer object in each iteration, and from the second iteration onward call trainer.train(resume_from_checkpoint=True). Each dataset should be trained for a single epoch. However, training does not continue in this loop: when resuming, the number of epochs the model has already been trained for is also reloaded from the checkpoint, so with a single epoch configured the run is treated as already finished. How can I reset the epoch count?
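```python
from transformers import Seq2SeqTrainer

# model, training_args, train_datasets, eval_dataset, tokenizer and
# data_collator are defined earlier.
for i, k in enumerate(sorted(train_datasets.keys())):
    train_dataset = train_datasets[k]
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,  # the keyword is train_dataset, not train_datasets
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )
    # Resume from the latest checkpoint for every dataset after the first one.
    trainer.train(resume_from_checkpoint=i > 0)
```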
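To illustrate why the loop stops, here is a simplified, hypothetical model of an epoch-based training loop that resumes from saved counters. It is not the transformers Trainer source: `TrainerState`, `simulate_train` and the fixed `steps_per_epoch` are assumptions made for illustration (in the real loop a different dataset size would also change the steps per epoch). Roughly as in the Trainer, the number of completed epochs is derived from the restored `global_step`, so a resumed run with `num_train_epochs=1` has no epochs left to run.

```python
from dataclasses import dataclass, replace
from typing import Optional, Tuple


@dataclass
class TrainerState:
    """Hypothetical stand-in for the counters saved in a checkpoint."""
    global_step: int = 0
    epoch: int = 0


def simulate_train(
    num_train_epochs: int,
    steps_per_epoch: int,
    resumed_state: Optional[TrainerState] = None,
) -> Tuple[TrainerState, int]:
    """Simplified epoch loop with resume; returns (final state, steps actually run)."""
    # Copy the restored state so the caller's checkpoint object is not mutated.
    state = replace(resumed_state) if resumed_state is not None else TrainerState()
    max_steps = num_train_epochs * steps_per_epoch
    # Completed epochs are inferred from the restored step counter.
    epochs_trained = state.global_step // steps_per_epoch
    steps_run = 0
    for epoch in range(epochs_trained, num_train_epochs):
        for _ in range(steps_per_epoch):
            if state.global_step >= max_steps:
                break
            state.global_step += 1
            steps_run += 1
        state.epoch = epoch + 1
    return state, steps_run


# First dataset: fresh start, one epoch.
first, ran = simulate_train(num_train_epochs=1, steps_per_epoch=100)
print("fresh run, steps:", ran)

# Second dataset: resuming with num_train_epochs=1 finds the only epoch already done.
_, ran = simulate_train(num_train_epochs=1, steps_per_epoch=100, resumed_state=first)
print("resumed with 1 epoch, steps:", ran)

# Only a larger epoch target leaves work to do after resuming.
_, ran = simulate_train(num_train_epochs=2, steps_per_epoch=100, resumed_state=first)
print("resumed with 2 epochs, steps:", ran)
```

In this model the restored counters, not the dataset, decide how much training remains, which matches the behaviour described in the question.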