With this code I get an error on `accelerator.get_state_dict(optimizer)`, so I guess that is not how it is intended to work, but I can't find an example of the intended way.
```python
import git  # GitPython, used to record the current commit
import torch


def save_checkpoint(model_state_dict, optimizer_state_dict,
                    samples, hp, checkpoint_path, accelerator):
    checkpoint = {
        'model_state_dict': model_state_dict,
        'optimizer_state_dict': optimizer_state_dict,
        'samples': samples,
        # RNG states, so training can be resumed reproducibly
        'cuda_rng_state_all': torch.cuda.get_rng_state_all(),
        'random_rng_state': torch.random.get_rng_state(),
        'hp': hp,
        'sha': git.Repo(search_parent_directories=True).head.object.hexsha,
    }
    accelerator.save(checkpoint, checkpoint_path)
    accelerator.wait_for_everyone()


save_checkpoint(accelerator.get_state_dict(model),
                accelerator.get_state_dict(optimizer),  # <- this call raises the error
                samples, hp, args.checkpoint_path, accelerator)
```
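
My best guess is that `get_state_dict` is only meant for models, and that the optimizer state should come from the prepared optimizer itself, or from `accelerator.save_state`. A minimal sketch of what I suspect is the intended pattern (assuming `model` and `optimizer` were both returned by `accelerator.prepare`; `checkpoint_dir` is a placeholder):

```python
# Sketch of what I suspect is the intended pattern -- not confirmed anywhere.
# Assumption: model and optimizer were both returned by accelerator.prepare(...).
accelerator.wait_for_everyone()
save_checkpoint(
    accelerator.get_state_dict(model),  # unwraps the model and gathers its state dict
    optimizer.state_dict(),             # guess: the prepared optimizer exposes state_dict() directly
    samples, hp, args.checkpoint_path, accelerator,
)

# Alternatively, accelerator.save_state() checkpoints the model, optimizer and
# RNG states in one call, which may be the intended API here:
# accelerator.save_state(checkpoint_dir)  # takes a directory, not a single file path
```

Is that what people actually use, or is there a dedicated way to checkpoint the optimizer?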