Saving optimizer

Hi sgugger,

Thanks for your reply; I managed to get the model saved and loaded.

On saving optimizer and scheduler, I did something like:

            accelerator.print("==> Saving optimizer <==")
            accelerator.save(optimizer.state_dict(), "optimizer.pth.tar")
            accelerator.print("==> Saving scheduler <==")
            accelerator.save(scheduler.state_dict(), "scheduler.pth.tar")

Then, when I load the checkpoints for the optimizer and scheduler right before prepare() (the same place where I load the model checkpoint), for example:

        optimizer.load_state_dict(torch.load("optimizer.pth.tar"))
        scheduler.load_state_dict(torch.load("scheduler.pth.tar"))
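
For context, here is roughly where those loads sit relative to prepare() inside my training function (a simplified sketch: the tiny Linear model, the AdamW/StepLR choices, and the file names are just stand-ins for my real setup):

    import torch
    from accelerate import Accelerator
    from torch.optim import AdamW
    from torch.optim.lr_scheduler import StepLR

    def map_fn():
        accelerator = Accelerator()

        # stand-ins for my real model / optimizer / scheduler
        model = torch.nn.Linear(10, 2)
        optimizer = AdamW(model.parameters(), lr=1e-4)
        scheduler = StepLR(optimizer, step_size=1)

        # the model checkpoint loads fine here
        model.load_state_dict(torch.load("model.pth.tar"))

        # optimizer / scheduler checkpoints, loaded in the same place,
        # right before prepare()
        optimizer.load_state_dict(torch.load("optimizer.pth.tar"))
        scheduler.load_state_dict(torch.load("scheduler.pth.tar"))

        model, optimizer = accelerator.prepare(model, optimizer)

        # ... training loop, scheduler.step(), and the accelerator.save() calls above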

I get a SIGKILL exception from the TPU indicating OOM…

And I thought: why not just load on the main process, then use prepare() to push it to all processes?

      if accelerator.is_main_process:
        optimizer.load_state_dict(torch.load("optimizer.pth.tar"))
        scheduler.load_state_dict(torch.load("scheduler.pth.tar"))
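
In other words, keeping the same setup as in the sketch above, the only change is guarding the loads so that only the main process reads the files, hoping that prepare() then pushes the loaded state out to every process:

    # same model / optimizer / scheduler setup as in the sketch above
    if accelerator.is_main_process:
        optimizer.load_state_dict(torch.load("optimizer.pth.tar"))
        scheduler.load_state_dict(torch.load("scheduler.pth.tar"))

    model, optimizer = accelerator.prepare(model, optimizer)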

It went okay.

But then, when saving the next checkpoint for the optimizer and scheduler (the model saved okay), I get an error:

Exception in device=TPU:0: 140247764057056
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
    fn(gindex, *args)
  File "/usr/local/lib/python3.7/dist-packages/accelerate/utils.py", line 305, in __call__
    self.launcher(*args)
  File "<ipython-input-19-05ee9dbbadb6>", line 183, in map_fn
    accelerator.save(optimizer.state_dict(), "optimizer.pth.tar")
  File "/usr/local/lib/python3.7/dist-packages/accelerate/optimizer.py", line 91, in state_dict
    return self.optimizer.state_dict()
  File "/usr/local/lib/python3.7/dist-packages/torch/optim/optimizer.py", line 120, in state_dict
    for k, v in self.state.items()}
  File "/usr/local/lib/python3.7/dist-packages/torch/optim/optimizer.py", line 120, in <dictcomp>
    for k, v in self.state.items()}
KeyError: 140247764057056
Exception in device=TPU:7: Cancelled: From /job:tpu_worker/replica:0/task:0:
Cancelled by TearDown.

I am so close to getting the code ready to train, but I'm out of clues on saving checkpoints… any idea how to save and load the optimizer and scheduler? Thanks.

Edit:
Related question?
https://github.com/huggingface/transformers/issues/4963