Reducing `load_state` memory usage

Is there a way to minimize the GPU memory usage when loading a checkpoint with accelerate?

```python
import torch
from accelerate import Accelerator
from transformers import AutoModel

accelerator = Accelerator()
model = AutoModel.from_pretrained(args.model)
optimizer = torch.optim.AdamW(model.parameters())
model, optimizer = accelerator.prepare(model, optimizer)  # This moves the model onto the GPU
accelerator.load_state(checkpoint_dir)  # This loads the checkpoint weights onto the GPU as well
```

Peak GPU memory usage is higher when loading a checkpoint this way, which suggests that accelerate doesn't load the weights in place: for a moment, both the prepared model and the checkpoint's copy of the weights sit on the GPU. I would like to avoid this extra memory overhead, but haven't found an official solution. I know about `accelerate.init_empty_weights`, but as far as I can tell it isn't meant to be used with `accelerator.prepare` and `accelerator.load_state`. Additionally, `accelerator.save_state` does not support sharded weights.
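The only workaround I can think of is to bypass `load_state` for the model weights: restore them on CPU and copy them into the existing parameters in place, so the GPU only ever holds one copy. Below is a minimal sketch for the single-process case; `weights.pt` is a hypothetical file saved earlier with `torch.save(model.state_dict(), ...)`, not accelerate's own checkpoint layout, and the optimizer state would need similar treatment.

```python
import os

import torch
from accelerate import Accelerator
from transformers import AutoModel

accelerator = Accelerator()
model = AutoModel.from_pretrained("bert-base-uncased")  # hypothetical model name
optimizer = torch.optim.AdamW(model.parameters())

# Restore the weights while the model is still on CPU, before prepare()
# moves it to the GPU. "weights.pt" is a hypothetical filename saved with
# torch.save(model.state_dict(), ...); load_state's own files are laid out
# differently depending on the accelerate version.
checkpoint_path = os.path.join("checkpoint_dir", "weights.pt")
if os.path.exists(checkpoint_path):
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state_dict)  # copies in place into the existing CPU tensors
    del state_dict  # drop the extra CPU copy before moving anything to the GPU

# prepare() now moves the already-restored weights to the GPU once, so peak
# GPU memory stays at a single copy of the model.
model, optimizer = accelerator.prepare(model, optimizer)
```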

Hi @pratogab, I don't think we have a way to minimize GPU memory usage when loading a checkpoint with accelerate. Since each method (DDP, FSDP, DeepSpeed) has its own way of loading the model in `prepare`/`load_state_dict`, it would be quite complicated to enable this.
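To make the difficulty concrete, here is a rough sketch (not accelerate's internals) of the branching a generic in-place loader would need; the comments describe each backend's constraint as I understand it:

```python
from accelerate import Accelerator
from accelerate.utils import DistributedType

accelerator = Accelerator()

# Each backend stores and restores weights differently, so a single
# generic "load in place" path is hard to provide:
if accelerator.distributed_type == DistributedType.DEEPSPEED:
    # DeepSpeed partitions parameters inside its engine, so checkpoints
    # have to be restored through the engine's own loading machinery.
    pass
elif accelerator.distributed_type == DistributedType.FSDP:
    # FSDP flattens and shards parameters across ranks; state dicts must
    # be gathered or scattered with FSDP's state-dict handling.
    pass
else:
    # Plain single-GPU / DDP: parameters are ordinary dense tensors, so a
    # CPU-side load_state_dict before prepare() (as sketched above) works.
    pass
```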