Hi sgugger,
Thanks for your reply; I managed to get the model saved and loaded.
To save the optimizer and scheduler, I did something like:
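For the record, the model part that now works looks roughly like this (a sketch; "model.pth.tar" is just a placeholder name, and the wait_for_everyone() barrier is my own addition):
# Make sure all processes are done before saving (my assumption that this is needed)
accelerator.wait_for_everyone()
# Unwrap the model that prepare() wrapped before saving its weights
unwrapped_model = accelerator.unwrap_model(model)
accelerator.save(unwrapped_model.state_dict(), "model.pth.tar")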
accelerator.print("==> Saving optimizer <==")
accelerator.save(optimizer.state_dict(), "optimizer.pth.tar")
accelerator.print("==> Saving scheduler <==")
accelerator.save(scheduler.state_dict(), "scheduler.pth.tar")
Then, when I load the optimizer and scheduler checkpoints right before prepare()
(the same place where I load the model checkpoint), for example:
optimizer.load_state_dict(torch.load("optimizer.pth.tar"))
scheduler.load_state_dict(torch.load("scheduler.pth.tar"))
I get a SIGKILL exception from the TPU indicating OOM…
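My guess (just an assumption on my part) is that under xmp.spawn those load lines run once per TPU core, so something like eight full copies of the optimizer state end up in host RAM at the same time:
# Every TPU process executes this simultaneously, so (my assumption)
# one full copy of the optimizer state per process sits in host RAM at once
state = torch.load("optimizer.pth.tar")
optimizer.load_state_dict(state)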
So I thought: why not just load on the main process only, and then let prepare()
push the state to all processes?
if accelerator.is_main_process:
    optimizer.load_state_dict(torch.load("optimizer.pth.tar"))
    scheduler.load_state_dict(torch.load("scheduler.pth.tar"))
It went okay.
But then, when saving the next checkpoint for the optimizer and scheduler (the model saved okay), I get an error:
Exception in device=TPU:0: 140247764057056
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
    fn(gindex, *args)
  File "/usr/local/lib/python3.7/dist-packages/accelerate/utils.py", line 305, in __call__
    self.launcher(*args)
  File "<ipython-input-19-05ee9dbbadb6>", line 183, in map_fn
    accelerator.save(optimizer.state_dict(), "optimizer.pth.tar")
  File "/usr/local/lib/python3.7/dist-packages/accelerate/optimizer.py", line 91, in state_dict
    return self.optimizer.state_dict()
  File "/usr/local/lib/python3.7/dist-packages/torch/optim/optimizer.py", line 120, in state_dict
    for k, v in self.state.items()}
  File "/usr/local/lib/python3.7/dist-packages/torch/optim/optimizer.py", line 120, in <dictcomp>
    for k, v in self.state.items()}
KeyError: 140247764057056
Exception in device=TPU:7: Cancelled: From /job:tpu_worker/replica:0/task:0:
Cancelled by TearDown.
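If I read the traceback right, torch keeps the optimizer's per-parameter state keyed by the parameter objects themselves, and state_dict() fails when the state holds a parameter that is no longer in param_groups. I don't know if this is exactly what happens in my run, but here is a minimal standalone sketch (my reconstruction, not my actual training code) that produces the same kind of KeyError:
import torch

# torch.optim keys per-parameter state by the parameter objects themselves.
# If the params in param_groups are later replaced (e.g. re-created when the
# optimizer/model is wrapped and moved to the device) while old state entries
# remain, state_dict() hits a stale key.
p_old = torch.nn.Parameter(torch.zeros(2))
p_old.grad = torch.ones(2)

opt = torch.optim.SGD([p_old], lr=0.1, momentum=0.9)
opt.step()  # populates opt.state[p_old] with a momentum buffer

p_new = torch.nn.Parameter(torch.zeros(2))
opt.param_groups[0]["params"] = [p_new]  # params swapped, old state left behind

opt.state_dict()  # raises KeyError: <id of p_old>, the same failure shape as above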
I am so close to having the code ready to train, but I'm out of clues on checkpoint saving… any idea how to save and load the optimizer and scheduler? Thanks.
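To be concrete, the overall shape I'm trying to make work is the following (the wait_for_everyone() barrier and the map_location="cpu" argument are guesses on my part, not something I have verified on TPU):
# Saving, after training with the prepared optimizer/scheduler
accelerator.wait_for_everyone()
accelerator.save(optimizer.state_dict(), "optimizer.pth.tar")
accelerator.save(scheduler.state_dict(), "scheduler.pth.tar")

# Loading, on every process, before accelerator.prepare(...)
optimizer.load_state_dict(torch.load("optimizer.pth.tar", map_location="cpu"))
scheduler.load_state_dict(torch.load("scheduler.pth.tar", map_location="cpu"))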
Edit:
Related question?
https://github.com/huggingface/transformers/issues/4963