Save only best model in Trainer

I have read previous posts on similar topics but could not tell whether there is a workaround to save only the best model rather than a checkpoint at every step. My disk fills up even after I set save_total_limit to 5, since the Trainer keeps saving checkpoints to disk from the start of training.
Please suggest.
Thanks

2 Likes

You can set save_strategy to "no" to avoid saving anything during training, then save the final model once training is done with trainer.save_model().
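A minimal sketch of that setup (the output directory name is a placeholder, and model, train_dataset, and val_dataset are assumed to be defined elsewhere):

```python
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="./out",
    save_strategy="no",           # no checkpoints written during training
    evaluation_strategy="steps",  # evaluation can still run on a schedule
    eval_steps=40,
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
)
trainer.train()
trainer.save_model("./out")       # one final save after training
```

Note that with save_strategy="no" nothing is checkpointed mid-run, so if training is interrupted you cannot resume from a checkpoint.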

2 Likes

Thank you, this is helpful.

I am using the below:

  args = TrainingArguments(
      output_dir=f"./out_fold{i}",
      overwrite_output_dir=True,
      evaluation_strategy="steps",
      eval_steps=40,
      logging_steps=40,
      learning_rate=5e-5,
      per_device_train_batch_size=8,
      per_device_eval_batch_size=8,
      num_train_epochs=10,
      seed=0,
      save_total_limit=1,
      # report_to="none",
      load_best_model_at_end=True,
      save_strategy="no",
  )
  trainer = Trainer(
      model=model,
      args=args,
      train_dataset=train_dataset,
      eval_dataset=val_dataset,
      # compute_metrics=compute_metrics,
      # callbacks=[EarlyStoppingCallback(early_stopping_pa)],
  )
  trainer.train()
  trainer.save_model(f"out_fold{i}")

Here, even though save_strategy = "no", checkpoints are still saved to disk at every evaluation (log below), which fills up the disk. Can you suggest what's going wrong?

***** Running Evaluation *****
Num examples = 567
Batch size = 8
Saving model checkpoint to ./out_fold0/checkpoint-40
Configuration saved in ./out_fold0/checkpoint-40/config.json
Model weights saved in ./out_fold0/checkpoint-40/pytorch_model.bin
Deleting older checkpoint [out_fold0/checkpoint-760] due to args.save_total_limit
***** Running Evaluation *****

1 Like

You can’t use load_best_model_at_end=True if you don’t want to save checkpoints: it needs to save checkpoints at every evaluation to make sure you have the best model, and it will always save 2 checkpoints (even if save_total_limit is 1): the best one and the last one (to resume an interrupted training).

6 Likes

If save_total_limit is set to some value, will checkpoints be replaced by newer checkpoints even if the newer checkpoints are underperforming?

The best checkpoint is always kept, as is the last checkpoint (to make sure you can resume training from it).
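For anyone curious, the retention rule described above can be sketched in plain Python. This is only an illustration of the behaviour (best and most recent checkpoints are protected, oldest others are deleted first), not the Trainer's actual rotation code:

```python
def rotate_checkpoints(checkpoints, best, limit):
    """Keep at most `limit` checkpoints, but never delete the best one or
    the most recent one. `checkpoints` is ordered oldest-to-newest.
    Returns (kept, deleted)."""
    if limit is None or len(checkpoints) <= limit:
        return list(checkpoints), []
    protected = {best, checkpoints[-1]}  # best model + last (for resuming)
    kept, deleted = [], []
    for ckpt in checkpoints:
        excess = len(checkpoints) - len(deleted) - limit
        if excess > 0 and ckpt not in protected:
            deleted.append(ckpt)  # oldest unprotected checkpoints go first
        else:
            kept.append(ckpt)
    return kept, deleted

# With save_total_limit=2 and the best model at step 80, the best and the
# latest checkpoints survive even though older ones are removed.
kept, gone = rotate_checkpoints(
    ["ckpt-40", "ckpt-80", "ckpt-120", "ckpt-160"], best="ckpt-80", limit=2)
# kept -> ["ckpt-80", "ckpt-160"], gone -> ["ckpt-40", "ckpt-120"]
```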

3 Likes

Thanks @sgugger

I believe we have to set these parameters, which will keep 2 checkpoints (the best one and the last one) and avoid saving checkpoints at every evaluation. Is that right?

save_total_limit = 2
save_strategy = "no"
load_best_model_at_end = False

@Vinayaks117, did those settings work for you to save the most recent and the best? I’d like to do the same.

@jbmaxwell Yes.

2 Likes

Great, thanks!

To be honest it didn't work for me (version 4.20.1); not sure what I'm missing.

So instead I'm running a cron job every 15 minutes to clean up those checkpoints.

The file-handle limit gets reached after about 10 hours, so as a last resort the cron job is cleaning up those files.

My bad, it does work; it was an issue with how I passed params through sys.argv.

@sgugger @prachi12 how do you know which checkpoint had the best performance and how do you load that specific checkpoint?

2 Likes

@artificial-cerebral I had the same question and couldn't find the answer in the documentation. After checking the source code, I found:

# Define your trainer, etc...
trainer.train()

# After training, access the path of the best checkpoint like this
best_ckpt_path = trainer.state.best_model_checkpoint
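To then load that checkpoint explicitly, a sketch (AutoModelForSequenceClassification is an assumption here; swap in whichever Auto class matches your task):

```python
from transformers import AutoModelForSequenceClassification

best_ckpt_path = trainer.state.best_model_checkpoint
model = AutoModelForSequenceClassification.from_pretrained(best_ckpt_path)
```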

2 Likes

@astariul - Thanks! That seems to be working. For future readers, another option would be to remove the whole directory (rm -r $save_path) after training and then call trainer.save_model(). That way only the best model is saved (assuming you had enabled load_best_model_at_end=True). Alternatively, if you don't want to delete the checkpoints, skip the rm -r $save_path and pass a new path to trainer.save_model(output_dir=new_path).
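A self-contained sketch of that cleanup step in Python (the directory layout is simulated for the example; in practice save_path would be your output_dir, and you would call trainer.save_model(save_path) afterwards):

```python
import os
import shutil
import tempfile

# Simulate an output_dir left behind by training: a few checkpoint folders.
save_path = tempfile.mkdtemp()
for step in (40, 80, 120):
    os.makedirs(os.path.join(save_path, f"checkpoint-{step}"))

# Remove only the checkpoint-* subdirectories, keeping the directory itself
# so a final save (e.g. trainer.save_model(save_path)) can write into it.
for name in os.listdir(save_path):
    if name.startswith("checkpoint-"):
        shutil.rmtree(os.path.join(save_path, name))

remaining = os.listdir(save_path)  # -> []: no checkpoints left
```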

Hey @artificial-cerebral, can you please share a code example of how you do that?