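You should just use `accelerator.save_state`.
To add the other items you want to save, such as the hyperparams, sha, and samples, register them with `accelerator.register_for_checkpointing(...)`. Every registered object must implement `state_dict()` and `load_state_dict()`, so a plain dictionary has to be wrapped in a small class that provides those two methods. You don't need to handle the RNG states yourself, because `save_state` already saves them automatically.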
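Here is a minimal sketch of that wrapper. `TrainingMetadata`, its fields, the placeholder values, and the `checkpoints/step_100` path are all hypothetical and only illustrate the idea. The sketch relies on `register_for_checkpointing` accepting any object that exposes `state_dict()` and `load_state_dict()`:

```python
import copy
from typing import Any, Dict, List, Optional


class TrainingMetadata:
    """Hypothetical container for extra run info (hyperparams, git sha, samples).

    register_for_checkpointing only accepts objects with state_dict() and
    load_state_dict(), so the plain values are wrapped in this class.
    """

    def __init__(
        self,
        hyperparams: Optional[Dict[str, Any]] = None,
        sha: str = "",
        samples: Optional[List[str]] = None,
    ) -> None:
        self.hyperparams = dict(hyperparams or {})
        self.sha = sha
        self.samples = list(samples or [])

    def state_dict(self) -> Dict[str, Any]:
        # Deep copy of plain Python types, so later mutations don't leak into a saved checkpoint.
        return copy.deepcopy(
            {"hyperparams": self.hyperparams, "sha": self.sha, "samples": self.samples}
        )

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        # Restore exactly the fields written by state_dict().
        self.hyperparams = dict(state["hyperparams"])
        self.sha = state["sha"]
        self.samples = list(state["samples"])


def main() -> None:
    # Imported here so the class above can be used without Accelerate installed.
    from accelerate import Accelerator

    accelerator = Accelerator()
    # model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

    metadata = TrainingMetadata(
        hyperparams={"lr": 3e-4, "batch_size": 32},  # placeholder values
        sha="abc1234",  # placeholder commit hash
        samples=["example prompt"],  # placeholder samples
    )
    # Register once; later save_state/load_state calls include this object.
    accelerator.register_for_checkpointing(metadata)

    # Saves prepared model/optimizer/scheduler state, the RNG states,
    # and metadata.state_dict() into the directory.
    accelerator.save_state("checkpoints/step_100")  # hypothetical path

    # When resuming: create and register the object the same way, then
    # load_state() passes the saved values to metadata.load_state_dict().
    accelerator.load_state("checkpoints/step_100")
```

In a real script you would call `main()`, or inline its body after `accelerator.prepare(...)`. Custom objects are saved by registration index, so on resume register them in the same order you used when saving.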
See the docs for more info: