How to save everything in one checkpoint?

With this code I get an error from accelerator.get_state_dict(optimizer), so I guess this is not how it is intended to work, but I can't find an example of the intended way.

import git  # GitPython, used to record the current commit sha
import torch

def save_checkpoint(model_state_dict, optimizer_state_dict,
                    samples, hp, checkpoint_path, accelerator):
    checkpoint = {
        'model_state_dict': model_state_dict,
        'optimizer_state_dict': optimizer_state_dict,
        'samples': samples,
        'cuda_rng_state_all': torch.cuda.get_rng_state_all(),
        'random_rng_state': torch.random.get_rng_state(),
        'hp': hp,
        'sha': git.Repo(search_parent_directories=True).head.object.hexsha,
    }
    accelerator.save(checkpoint, checkpoint_path)

accelerator.wait_for_everyone()
save_checkpoint(accelerator.get_state_dict(model),
                accelerator.get_state_dict(optimizer),
                samples, hp, args.checkpoint_path, accelerator)

You should just use accelerator.save_state.

To add the other items you want to save, call accelerator.register_for_checkpointing({...}) with that dictionary containing everything else, such as the hyperparams, sha, and samples. (The RNG states are already saved automatically by save_state.)

See the docs for more info:

Hmm, it seems I can't just save any dict; I need a special object with state_dict and load_state_dict methods defined, which doesn't sound very convenient.
Is there a better way around this?

  File "/opt/conda/lib/python3.10/site-packages/accelerate/accelerator.py", line 2516, in register_for_checkpointing
    raise ValueError(err)
ValueError: All `objects` must include a `state_dict` and `load_state_dict` function to be stored. The following inputs are invalid:
	- Item at index 0, `dict`
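
One possible workaround, going by the error message above, is to wrap the extra items in a small object that exposes state_dict and load_state_dict. The sketch below is only an illustration: CheckpointExtras is a hypothetical helper, not part of accelerate, and the values passed in are placeholders.

```python
class CheckpointExtras:
    """Hypothetical wrapper (not part of accelerate) that exposes the two
    methods register_for_checkpointing asks for in the error above."""

    def __init__(self, **values):
        self.values = dict(values)

    def state_dict(self):
        # Return a shallow copy so the saved dict is decoupled from self.values.
        return dict(self.values)

    def load_state_dict(self, state_dict):
        # Replace the current contents with what was stored in the checkpoint.
        self.values = dict(state_dict)


# Placeholder values standing in for hp, sha, and samples from the question.
extras = CheckpointExtras(hp={"lr": 3e-4}, sha="<commit sha>", samples=1000)

# Assumed usage with an existing Accelerator instance (not executed here):
# accelerator.register_for_checkpointing(extras)
# accelerator.save_state(checkpoint_dir)
# ... later, after registering the same kind of object again ...
# accelerator.load_state(checkpoint_dir)
# hp = extras.values["hp"]

# Round trip without accelerate, to show what the two methods do.
restored = CheckpointExtras()
restored.load_state_dict(extras.state_dict())
```

With this approach the object would need to be registered before save_state is called, and registered again before load_state when resuming, so that the stored values have somewhere to be loaded into.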