Resume Training with Lower Learning Rate

I’m training a model that started to diverge during the warmup stage. I would like to continue from the last stable checkpoint with a lower learning rate.

I tried passing a new learning rate to TrainingArguments, but it gets overwritten when the checkpoint is loaded: trainer.train(resume_from_checkpoint='checkpoint') restores the old learning rate from the saved optimizer/scheduler state, so my new value is ignored.
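
Roughly what I'm doing (model and train_dataset stand in for my actual setup, and the values are just illustrative):

```python
from transformers import Trainer, TrainingArguments

# model and train_dataset are placeholders for my real model and dataset.
args = TrainingArguments(
    output_dir="out",
    learning_rate=1e-5,   # the new, lower learning rate
)
trainer = Trainer(model=model, args=args, train_dataset=train_dataset)

# Resuming restores the optimizer/scheduler state saved in the checkpoint,
# so the learning rate above is replaced by the old schedule.
trainer.train(resume_from_checkpoint="checkpoint")
```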

How can I resume training from a checkpoint, skip the batches already trained on (using the same RNG seed), but also use a lower learning rate? Is this even possible without doing something hacky with the Trainer code?
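
The hackiest workaround I can think of is a TrainerCallback that overrides the learning rate once training starts. This is an untested sketch: it assumes the checkpoint's optimizer/scheduler state has already been restored by the time on_train_begin fires, and that the scheduler is a LambdaLR-style one that recomputes the learning rate from base_lrs on every step:

```python
from transformers import Trainer, TrainerCallback

class OverrideLearningRate(TrainerCallback):
    """Untested: force a new base learning rate after resuming from a checkpoint."""

    def __init__(self, new_lr):
        self.new_lr = new_lr

    def on_train_begin(self, args, state, control, optimizer=None, lr_scheduler=None, **kwargs):
        # Replace the learning rate restored from the checkpoint's optimizer state.
        if optimizer is not None:
            for group in optimizer.param_groups:
                group["lr"] = self.new_lr
        # Also patch base_lrs, since a LambdaLR-style scheduler recomputes the
        # learning rate from base_lrs on every step and would undo the change above.
        if lr_scheduler is not None and hasattr(lr_scheduler, "base_lrs"):
            lr_scheduler.base_lrs = [self.new_lr] * len(lr_scheduler.base_lrs)

trainer = Trainer(
    model=model,                      # same placeholders as above
    args=args,
    train_dataset=train_dataset,
    callbacks=[OverrideLearningRate(1e-5)],
)
trainer.train(resume_from_checkpoint="checkpoint")
```

Even if this works, it feels fragile, hence my question about whether there is a cleaner, supported way.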

Additionally, I spent a significant amount of time trying to resume training with different hyperparameters before realizing the new values were being overwritten by resume_from_checkpoint. This behavior is not obvious from a user's perspective and should be much more explicit.