Resume Training with Lower Learning Rate

I’m training a model which started to diverge during the warmup stage of training. I would like to continue training from the last stable checkpoint with a lower learning rate.

I tried passing a new learning rate to TrainingArguments, but it gets overwritten when the checkpoint is loaded: trainer.train(resume_from_checkpoint='checkpoint') restores the old learning rate from the checkpoint’s optimizer and scheduler state.

How can I resume training from a checkpoint, skip the batches already trained on using the same RNG seed, but also use a lower learning rate? Is this even possible without doing something hacky with the Trainer code?

Additionally, I spent a significant amount of time trying to resume training with different hyperparameters before realizing that the new values were being overwritten by resume_from_checkpoint. This behavior is not obvious from a user’s perspective and should be much more explicit.

Same problem, anyone help?

Same here. 1k views and no answer yet

To resume training from a checkpoint with a lower learning rate while maintaining the same RNG seed and skipping already trained batches, you have a few options:

  1. Modify the checkpoint: Before loading, you could manually edit the optimizer and scheduler state files that Trainer saves inside the checkpoint directory (optimizer.pt and scheduler.pt) to lower the stored learning rate. This approach is somewhat hacky and requires understanding the checkpoint file structure.
  2. Override after loading: Immediately after the checkpoint is loaded but before training resumes, manually override the optimizer’s learning rate, e.g. from a callback (see the sketch after this list). This method is less intrusive but requires careful timing in your code.
  3. Custom checkpoint loading: Implement a custom checkpoint loading function that restores the model state and optimizer state separately, allowing you to modify the learning rate before fully reconstructing the optimizer.
  4. Use a learning rate scheduler: Instead of directly setting a lower learning rate, use a learning rate scheduler that reduces the rate after resuming from the checkpoint.

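For option 2, here is a minimal sketch using a TrainerCallback. It assumes a recent transformers version in which the optimizer and scheduler are passed to callbacks as keyword arguments and the checkpoint’s optimizer/scheduler state is restored before on_train_begin fires; LowerLearningRateCallback is a hypothetical helper, not part of the library, so verify the behavior against your version.

```python
from transformers import TrainerCallback


class LowerLearningRateCallback(TrainerCallback):
    """Hypothetical callback that forces a new peak learning rate after the
    checkpoint's optimizer/scheduler state has been restored."""

    def __init__(self, new_lr: float):
        self.new_lr = new_lr

    def on_train_begin(self, args, state, control, optimizer=None, lr_scheduler=None, **kwargs):
        # When resuming, the Trainer loads optimizer.pt / scheduler.pt before this
        # event fires, so values set here take effect for the rest of the run.
        if optimizer is not None:
            for param_group in optimizer.param_groups:
                param_group["lr"] = self.new_lr
        # LambdaLR-style schedulers recompute the lr from base_lrs on every step,
        # so those base values have to be lowered too or the override is undone.
        if lr_scheduler is not None and hasattr(lr_scheduler, "base_lrs"):
            lr_scheduler.base_lrs = [self.new_lr] * len(lr_scheduler.base_lrs)
```

You would then pass the callback when building the Trainer, e.g. Trainer(..., callbacks=[LowerLearningRateCallback(1e-5)]), and call trainer.train(resume_from_checkpoint='checkpoint') as before. Lowering base_lrs scales the remainder of the existing schedule down proportionally rather than restarting it, which is usually what you want after a divergence.
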
Regarding the issue with hyperparameters being overwritten, you’re right that this behavior should be more explicit. Here are some suggestions to improve the user experience:

  1. Add clear documentation: Explicitly state in the resume_from_checkpoint documentation that it restores the optimizer and scheduler state from the checkpoint, overriding the corresponding training arguments.
  2. Implement a warning system: Add a warning message when resume_from_checkpoint overwrites user-specified hyperparameters (a user-side version of this check is sketched after this list).
  3. Provide an option: Add a parameter to resume_from_checkpoint that allows users to choose whether to use checkpoint hyperparameters or newly specified ones.
  4. Return overwritten values: Have resume_from_checkpoint return a dictionary of any overwritten values, allowing users to easily see what changed.
  5. Two-step process: Split the process into loading the checkpoint and applying the hyperparameters, giving users more control over what gets overwritten.

These changes would make the behavior more transparent and give users more control over the resumption process.
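
Until something like suggestion 2 lands in the library, a user-side version of that warning could look like the sketch below. It assumes the checkpoint layout that Trainer writes (an optimizer.pt file containing the optimizer state dict); warn_if_checkpoint_overrides_lr is a hypothetical helper name.

```python
import os

import torch


def warn_if_checkpoint_overrides_lr(checkpoint_dir: str, requested_lr: float) -> None:
    """Hypothetical helper: peek at the saved optimizer state and warn when
    resuming would silently continue from a learning rate different from the
    one passed to TrainingArguments."""
    optimizer_path = os.path.join(checkpoint_dir, "optimizer.pt")
    if not os.path.isfile(optimizer_path):
        return
    optimizer_state = torch.load(optimizer_path, map_location="cpu")
    saved_lrs = sorted({group["lr"] for group in optimizer_state["param_groups"]})
    if saved_lrs != [requested_lr]:
        print(
            f"Warning: checkpoint at {checkpoint_dir} stores lr={saved_lrs}; "
            f"resuming will continue from that state rather than the requested lr={requested_lr}."
        )


# Example usage before resuming:
# warn_if_checkpoint_overrides_lr("checkpoint", training_args.learning_rate)
```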