Is there a way to minimize the GPU memory usage when loading a checkpoint with accelerate?
import torch
from accelerate import Accelerator
from transformers import AutoModel

accelerator = Accelerator()
model = AutoModel.from_pretrained(args.model)
optimizer = torch.optim.AdamW(model.parameters())
model, optimizer = accelerator.prepare(model, optimizer)  # This loads the model on the GPU
accelerator.load_state(checkpoint_dir)  # This loads the checkpoint weights on the GPU as well
GPU memory usage is higher when a checkpoint is loaded than when the model is loaded without one, which suggests that accelerate does not load the weights in place. I would like to avoid this extra memory overhead, but I haven't found an official solution. I know about accelerate.init_empty_weights, but as far as I can tell it isn't meant to be used together with accelerator.prepare and accelerator.load_state. Additionally, accelerator.save_state does not support sharded weights.
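For context, my understanding is that init_empty_weights targets an inference-style loading flow roughly like the sketch below rather than resuming training; the use of load_checkpoint_and_dispatch, the device_map value, and the assumption that checkpoint_dir contains loadable model weights are my own guesses, not something taken from the accelerate checkpointing docs:

from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModel

# Build the model on the meta device so no memory is allocated for the weights yet
config = AutoConfig.from_pretrained(args.model)
with init_empty_weights():
    empty_model = AutoModel.from_config(config)

# Materialize the weights straight from the checkpoint onto the target device(s)
model = load_checkpoint_and_dispatch(empty_model, checkpoint_dir, device_map="auto")

This avoids the duplicate allocation for inference, but it bypasses accelerator.prepare entirely and does not restore the optimizer state, which is why it doesn't seem usable together with load_state for resuming training.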