Avoid saving deepspeed optimizer and model states at checkpoints

Hello all,

I am using the DeepSpeed ZeRO-3 optimizer to fine-tune a 7B LLaMA-based model for 3 epochs. After the 1st epoch, the model starts to overfit: the evaluation loss becomes much higher than at previous steps. I would like to check the inference performance of the model at each checkpoint (save_strategy = 'epoch' is already set).

However, the "_optim_states.pt" and "_model_states.pt" files are taking up a huge amount of storage. I'll never continue fine-tuning from these checkpoints; do I still need to keep them for inference?

I think I won't need the optim_states files since they are optimizer-related, but what about model_states? I guess the HF Trainer uses pytorch_model-00001-of-00002.bin and pytorch_model-00002-of-00002.bin during inference, but I want to be sure about it.
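For what it's worth, here is a rough, stdlib-only sketch of how one might sort checkpoint files into "needed only to resume training" versus "needed for inference". The file patterns are assumptions based on a typical Trainer + DeepSpeed ZeRO checkpoint layout, so adjust them to match what you actually see on disk:

```python
import fnmatch

# Patterns assumed from a typical HF Trainer + DeepSpeed ZeRO checkpoint
# layout; verify against your own checkpoint-*/ folders before deleting.
DEEPSPEED_STATE_PATTERNS = [
    "*_optim_states.pt",   # sharded optimizer state (resume-only)
    "*_model_states.pt",   # DeepSpeed engine/module state (resume-only)
]
INFERENCE_PATTERNS = [
    "pytorch_model*.bin",  # consolidated weight shards loaded by from_pretrained()
    "model*.safetensors",  # newer Trainer versions save safetensors instead
    "config.json",
]

def classify(filename: str) -> str:
    """Return 'resume-only', 'inference', or 'other' for a checkpoint file name."""
    if any(fnmatch.fnmatch(filename, p) for p in DEEPSPEED_STATE_PATTERNS):
        return "resume-only"
    if any(fnmatch.fnmatch(filename, p) for p in INFERENCE_PATTERNS):
        return "inference"
    return "other"

print(classify("zero_pp_rank_0_mp_rank_00_optim_states.pt"))  # resume-only
print(classify("pytorch_model-00001-of-00002.bin"))           # inference
```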


It looks really hard…


Thanks for the reply! I would still like to understand what "_model_states.pt" is for. I am removing the optim_states files (they're really huge) with this callback, in case anyone is interested:

import glob
import os

from transformers import TrainerCallback


class CleanupCallback(TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        """
        Deletes optimizer checkpoint files from DeepSpeed checkpoints 
        inside `checkpoint-*/global_step*` folders.
        """
        # Get all DeepSpeed checkpoint folders
        checkpoint_dirs = sorted(
            glob.glob(os.path.join(args.output_dir, "checkpoint-*/global_step*")),
            key=os.path.getctime,
            reverse=True,
        )

        if not checkpoint_dirs:
            print("No DeepSpeed checkpoints found for cleanup.")
            return control

        latest_checkpoint = checkpoint_dirs[0]  # Get the latest saved checkpoint

        # Build file paths safely
        optimizer_files = [
            os.path.join(latest_checkpoint, "optimizer.pt"),
            *glob.glob(os.path.join(latest_checkpoint, "*_optim_states.pt")),
        ]

        deleted_count = 0
        for file in optimizer_files:
            try:
                os.remove(file)
                print(f"Deleted: {file}")
                deleted_count += 1
            except Exception as e:
                print(f"Error deleting {file}: {e}")

        # Print summary
        if deleted_count > 0:
            print(f"Deleted {deleted_count} optimizer-related files from: {latest_checkpoint}")

        return control
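Before wiring the callback into a real run, the glob pattern can be sanity-checked with a quick stdlib-only dry run on a fake checkpoint tree (the directory and file names below are just placeholders mimicking a typical layout):

```python
import glob
import os
import tempfile

# Build a fake checkpoint layout: checkpoint-100/global_step100/ with mixed files.
root = tempfile.mkdtemp()
step_dir = os.path.join(root, "checkpoint-100", "global_step100")
os.makedirs(step_dir)
for name in ["zero_pp_rank_0_mp_rank_00_optim_states.pt",
             "mp_rank_00_model_states.pt"]:
    open(os.path.join(step_dir, name), "w").close()

# Same pattern the callback uses: only *_optim_states.pt files are targeted,
# so the model_states file should survive.
targets = glob.glob(os.path.join(step_dir, "*_optim_states.pt"))
for f in targets:
    os.remove(f)

print(sorted(os.listdir(step_dir)))  # ['mp_rank_00_model_states.pt']
```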

I am still wondering whether I need to keep model_states.pt for inference.
