Checkpoint vs model weights

Please could you clarify the difference between a checkpoint and saving the weights of the model?
Which one can I use to load later?
Also, I could not find my checkpoints (maybe an overwrite issue at my end), so can the same be done via these lines of code?

trainer.save_model("/content/drive//results/distillbert/trainer")

tokenizer.save_pretrained("/content/drive/results/distillbert/tokenizer")
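
For what it's worth, I think directories saved like that can normally be loaded back with from_pretrained; a rough sketch, reusing the paths above (the Auto classes are just my assumption of what you'd use):

from transformers import AutoModelForSequenceClassification, AutoTokenizer

# reload the fine-tuned model and tokenizer from the saved directories
model = AutoModelForSequenceClassification.from_pretrained("/content/drive//results/distillbert/trainer")
tokenizer = AutoTokenizer.from_pretrained("/content/drive/results/distillbert/tokenizer")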

I think a “checkpoint” is what we call a partial save during training.

To take a checkpoint during training, you can save the model's state_dict, which is a dictionary holding the current values of all the model's parameters and buffers, i.e. the weights as they stand at that point in the training run.
Note that this doesn't save the model's architecture or configuration, so a state_dict on its own is not enough to rebuild the model.
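
If it helps to see it concretely, the state_dict is just an ordered dictionary of named tensors; a quick sketch, assuming model is whatever model you already have in memory:

# peek at the first few entries of the state_dict
sd = model.state_dict()
for name, tensor in list(sd.items())[:5]:
    print(name, tuple(tensor.shape))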

To reload the model to that checkpoint state, you first have to build a complete model with the right configuration. You can do this either by initialising it randomly from the config file, or by loading a suitable pre-trained model. Then you update that complete model with the saved state_dict weights.
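
For example, the config route would look roughly like this (only a sketch; I'm reusing NCLASSES and the checkpoint file name from my example further down):

import torch
from transformers import BertConfig, BertForSequenceClassification

# build a randomly-initialised model with the right shape from a config...
config = BertConfig(num_labels=NCLASSES)
model = BertForSequenceClassification(config)

# ...then overwrite its weights with the saved state_dict
model.load_state_dict(torch.load('/content/drive/My Drive/ftregmod-20200911-014657'))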

If you want to continue training from the same point, you also need to restore the optimizer (and, if you use one, the scheduler). Both of these have their own state_dict that can be saved and loaded in the same way.
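
The scheduler part can be handled just like the optimizer; a sketch, assuming a linear warm-up scheduler from transformers and a placeholder file name that simply follows the same pattern as below:

import torch
from transformers import get_linear_schedule_with_warmup

# total_steps stands in for whatever value you used when the scheduler was first created
scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,
                                            num_training_steps=total_steps)

# save during training...
torch.save(scheduler.state_dict(), '/content/drive/My Drive/ftregsched-20200911-014657')
# ...and restore before resuming
scheduler.load_state_dict(torch.load('/content/drive/My Drive/ftregsched-20200911-014657'))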

I don't have any examples of using save_model or save_pretrained, but here's an example of saving a model and optimizer during training:

import datetime
import torch

# timestamp the file names so each checkpoint is kept separately
filedt = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
torch.save(model.state_dict(), '/content/drive/My Drive/ftregmod-' + filedt)
torch.save(optimizer.state_dict(), '/content/drive/My Drive/ftregopt-' + filedt)

and then to reload and continue training:

import torch
from transformers import BertForSequenceClassification, AdamW

READFROMNAMEMODEL = '/content/drive/My Drive/ftregmod-20200911-014657'  ####
READFROMNAMEOPT = '/content/drive/My Drive/ftregopt-20200911-014657'  ####

# rebuild a complete model with the right architecture...
model = BertForSequenceClassification.from_pretrained('bert-base-uncased',
                                                      num_labels=NCLASSES,
                                                      output_attentions=True)

# ...then overwrite its weights with the checkpointed ones
model.load_state_dict(torch.load(READFROMNAMEMODEL), strict=False)

# recreate the optimizer (NCLASSES and LEARNRATE are defined elsewhere in the script)
optimizer = AdamW(model.parameters(),
                  lr=LEARNRATE,  # default is 5e-5
                  eps=1e-8)      # default is 1e-8

# restore the optimizer state so training carries on from where it stopped
optimizer.load_state_dict(torch.load(READFROMNAMEOPT))

thanks @rgwatwormhill