When training a model with something like:
from transformers import EncoderDecoderModel, Seq2SeqTrainer, Seq2SeqTrainingArguments

# (tokenizer, batch_size, train_data and val_data are defined earlier)
multibert = EncoderDecoderModel.from_pretrained("super-seq2seq-model")
# set training arguments - these params are not really tuned, feel free to change
training_args = Seq2SeqTrainingArguments(
    output_dir="path/to/mymodel/",
    evaluation_strategy="steps",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    predict_with_generate=True,
    logging_steps=500,  # set to 1000 for full training
    save_steps=500,
    eval_steps=500,  # set to 8000 for full training
    warmup_steps=2000,
    max_steps=16,  # delete for full training
    overwrite_output_dir=True,
    save_total_limit=3,
    fp16=True,
)
# instantiate trainer
trainer = Seq2SeqTrainer(
    model=multibert,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
)
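# start training
trainer.train()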
Training then creates model checkpoints in the output directory, e.g.
$ ls path/to/mymodel/
checkpoint-4500 checkpoint-5000 checkpoint-5500
And when that training run ends, I reload the latest checkpoint and try to continue training:
multibert = EncoderDecoderModel.from_pretrained("path/to/mymodel/checkpoint-5500")
training_args = Seq2SeqTrainingArguments(
    output_dir="path/to/mymodel/",
    evaluation_strategy="steps",
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    predict_with_generate=True,
    logging_steps=500,  # set to 1000 for full training
    save_steps=500,
    eval_steps=500,  # set to 8000 for full training
    warmup_steps=2000,
    max_steps=16,  # delete for full training
    overwrite_output_dir=True,
    save_total_limit=5,
    fp16=True,
)
# instantiate trainer
trainer = Seq2SeqTrainer(
    model=multibert,
    tokenizer=tokenizer,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
)
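# continue training from the reloaded checkpoint weights
trainer.train()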
When checkpoints are saved in this second (continued) training round, the step counter resets and starts from 500 again, i.e.
$ ls path/to/mymodel/
checkpoint-500 checkpoint-5000 checkpoint-5500
Is that the expected checkpoint-saving behavior?
I could save the second run to a different output directory, but is there some way / argument in Seq2SeqTrainingArguments
to tell the trainer to continue the checkpoint counter from where the first run stopped, i.e. from 5500 + 500?
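That is, after the second round I would want the output directory to look something like this (hypothetical desired listing, continuing from 5500):
$ ls path/to/mymodel/
checkpoint-5000 checkpoint-5500 checkpoint-6000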