Saving optimizer

According to the documentation, prepare() is used to send the model, optimizer, data loaders, etc. to each TPU core.

If I want to save the model, I first unwrap it with unwrap_model().

Is there a function to unwrap the optimizer? I assume unwrap_model() just detaches an object from all the cores, so could I also use it as unwrap_model(optimizer)?


The optimizer wrapper in the Accelerate library does not add anything to its state dictionary, so you can just call optimizer.state_dict() and save the result.
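A minimal sketch of that, using a plain PyTorch optimizer and scheduler on the CPU so it runs without a TPU (the traceback below shows Accelerate's wrapper simply returning `self.optimizer.state_dict()`). The helper `save_training_state`, the single-file layout, and the file name are hypothetical; in an Accelerate script you could pass `accelerator.unwrap_model(model)` and call `accelerator.save` instead of `torch.save`.

```python
import torch
from torch import nn


def save_training_state(model, optimizer, scheduler, path):
    """Hypothetical helper: write model, optimizer and scheduler state to one file."""
    checkpoint = {
        # With Accelerate: accelerator.unwrap_model(model).state_dict()
        "model": model.state_dict(),
        # The Accelerate optimizer wrapper adds nothing to this dict.
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    }
    # With Accelerate: accelerator.save(checkpoint, path)
    torch.save(checkpoint, path)


# Tiny CPU-only stand-ins for the real model, optimizer and scheduler.
model = nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# One training step so AdamW creates per-parameter state (exp_avg, exp_avg_sq).
loss = model(torch.randn(8, 4)).sum()
loss.backward()
optimizer.step()
scheduler.step()

save_training_state(model, optimizer, scheduler, "checkpoint.pth.tar")
```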

Hi sgugger,

Thanks for your reply; I managed to save and load the model.

For saving the optimizer and scheduler, I did something like:

            accelerator.print("==> Saving optimizer <==")
  ,  "optimizer.pth.tar")
            accelerator.print("==> Saving scheduler <==")
  , "scheduler.pth.tar")

Then, when I load the optimizer and scheduler checkpoints right before prepare() (the same place where I load the model checkpoint), for example:


I get a SIGKILL from the TPU, indicating an out-of-memory (OOM) error…

So I thought: why not load them only on the main process, then use prepare() to push them to all processes?

      if accelerator.is_main_process:

It went okay.

But then, when saving the next checkpoint of the optimizer and scheduler (the model saves fine), I get this error:

Exception in device=TPU:0: 140247764057056
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/", line 323, in _start_fn
    fn(gindex, *args)
  File "/usr/local/lib/python3.7/dist-packages/accelerate/", line 305, in __call__
  File "<ipython-input-19-05ee9dbbadb6>", line 183, in map_fn, "optimizer.pth.tar")
  File "/usr/local/lib/python3.7/dist-packages/accelerate/", line 91, in state_dict
    return self.optimizer.state_dict()
  File "/usr/local/lib/python3.7/dist-packages/torch/optim/", line 120, in state_dict
    for k, v in self.state.items()}
  File "/usr/local/lib/python3.7/dist-packages/torch/optim/", line 120, in <dictcomp>
    for k, v in self.state.items()}
KeyError: 140247764057056
Exception in device=TPU:7: Cancelled: From /job:tpu_worker/replica:0/task:0:
Cancelled by TearDown.
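The number in `KeyError: 140247764057056` looks like a Python object id. When `torch.optim`'s `state_dict()` packs the state, it maps every tensor key in `optimizer.state` to its position in `param_groups` through a dict keyed by `id(param)`, so a tensor that has state but is not in `param_groups` raises this kind of `KeyError`. A minimal plain-PyTorch sketch of one way to trigger it (an illustration, not a confirmed diagnosis of the setup above; `stale_param` is a hypothetical stand-in for a parameter object that was replaced by a new one):

```python
import torch
from torch import nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# Hypothetical: state left behind for a tensor the optimizer no longer tracks,
# e.g. an old parameter object that was swapped for a new one.
stale_param = nn.Parameter(torch.zeros(2, 4))
optimizer.state[stale_param] = {"momentum_buffer": torch.zeros(2, 4)}

missing_key = None
try:
    optimizer.state_dict()
except KeyError as err:
    # The missing key is id(stale_param), the same shape of error as in the traceback.
    missing_key = err.args[0]
    print("KeyError for object id", missing_key)
```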

I am so close to having the code ready to train but have run out of ideas for saving checkpoints… Any idea how to save and load the optimizer and scheduler? Thanks.

Related question?

Make sure you have a source install of Accelerate, as a bug in optimizer state reloading was fixed last week. Then, can you try loading the state dictionary on the CPU? That should avoid the OOM:

optimizer.load_state_dict(torch.load("optimizer.pth.tar", map_location="cpu"))

You should load the state in all processes as there is nothing that will synchronize them otherwise.
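A sketch of that advice with plain PyTorch on the CPU: the same file names as above, `torch.load(..., map_location="cpu")`, and a restore step that every process would run before `accelerator.prepare()`. The helpers `make_objects` and `load_optimizer_and_scheduler` and the tiny `nn.Linear` model are hypothetical stand-ins for the real training setup.

```python
import torch
from torch import nn


def make_objects():
    """Hypothetical: build a fresh model/optimizer/scheduler (stand-ins for the real ones)."""
    torch.manual_seed(0)
    model = nn.Linear(4, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    return model, optimizer, scheduler


def load_optimizer_and_scheduler(optimizer, scheduler):
    """Hypothetical helper: run in *every* process, before accelerator.prepare()."""
    # map_location="cpu" keeps the loaded tensors in host memory instead of
    # restoring them onto the device they were saved from.
    optimizer.load_state_dict(torch.load("optimizer.pth.tar", map_location="cpu"))
    scheduler.load_state_dict(torch.load("scheduler.pth.tar", map_location="cpu"))


# First run: one training step, then save the two state dicts separately.
model, optimizer, scheduler = make_objects()
model(torch.randn(8, 4)).sum().backward()
optimizer.step()
scheduler.step()
torch.save(optimizer.state_dict(), "optimizer.pth.tar")
torch.save(scheduler.state_dict(), "scheduler.pth.tar")

# Restarted run: fresh objects, then restore their state from disk.
new_model, new_optimizer, new_scheduler = make_objects()
load_optimizer_and_scheduler(new_optimizer, new_scheduler)
```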

We also have utilities to make checkpointing easier on the roadmap, so hopefully it will get easier soon!

So I installed Accelerate and XLA from source:

! pip install cloud-tpu-client==0.10

! pip install git+

Using map_location="cpu" still results in a SIGKILL OOM exception… :confused:

That’s what we use in the Trainer, so I’m a bit out of options. Do you have a minimal reproducer we could investigate?

Hi sgugger, here is an example Colab notebook. Please use a VM with more RAM, such as 32 GB:

If you run the code there, it will save all the weights in the first loop. Then stop the loop, change the FROM_CHECKPOINT flag to True, and rerun accelerate_launcher(); it will then load the weights.


I get the SIGKILL signal at the first loop when executing your notebook.

Hi sgugger, sorry for the inconvenience.

Can you delete all the code in this notebook and replace it with the code that triggers the SIGKILL?

This notebook comes from the PyTorch XLA demo and has an instance with more RAM (35 GB), so it should work for debugging purposes. The standard Colab instance only has 15 GB.


The RAM does not depend on the notebook but on the user. I have no access to a notebook with 35 GB of RAM. In general, it looks like you are using a model that simply uses too much RAM for Google Colab.

Hi sgugger,

I learned this from reading their GitHub: they instantiate a VM with 35 GB for the demo so it won’t crash. Since your backend is also PyTorch XLA and we are debugging a problem, I believe it is OK to use it.


But anyway, with this method I can load and save the model, optimizer, and scheduler with no problem. When restarting from a checkpoint, though, the model loads fine but the optimizer and scheduler always run out of memory.

I can try a larger-RAM instance on GCP.

Hi sgugger, I have now set up the environment for the GCP TPU…

I wonder why accelerate config does not ask for the IP address of the TPU nodes.
I then ran accelerate test and, not surprisingly, got a TPU connection timeout (Failed to connect to client mesh master), because it does not know where the TPU node is.

Background info:
I set up the VM with the PyTorch XLA image
and added the TPU address:

export XRT_TPU_CONFIG="tpu_worker;0;$TPU_IP_ADDRESS:8470"

and successfully ran the PyTorch XLA test script from their tutorial with TPU acceleration.

You have to do the export; that is what sets the IP address of the TPU. On my side, I could run my scripts successfully after that step.
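If you launch from a Python notebook instead of a shell, note that `!export ...` runs in a throwaway subshell, so the variable never reaches the kernel. A minimal sketch of setting the same value from Python, assuming it is done before `torch_xla` is imported; `xrt_tpu_config` is a hypothetical helper and `10.0.0.2` a placeholder address used only when `TPU_IP_ADDRESS` is unset.

```python
import os


def xrt_tpu_config(tpu_ip_address: str, port: int = 8470) -> str:
    """Hypothetical helper: same value as the shell export, tpu_worker;0;<ip>:8470."""
    return f"tpu_worker;0;{tpu_ip_address}:{port}"


# Set this before importing torch_xla / launching Accelerate.
# "10.0.0.2" is a placeholder; use your TPU node's internal IP address.
os.environ["XRT_TPU_CONFIG"] = xrt_tpu_config(os.environ.get("TPU_IP_ADDRESS", "10.0.0.2"))
```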

Thanks it works!

It trains, but I have the same problem loading the optimizer when restarting from a checkpoint. The optimizer file is only 1.1 GB, yet when it gets loaded, RAM usage shoots up to 64 GB (the VM has 64 GB) and, sigh… SIGKILL.
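A back-of-envelope sketch of how much host RAM the checkpoint alone could account for, assuming all 8 TPU processes (TPU:0 to TPU:7 in the traceback above) hold the 1.1 GB optimizer state at the same moment. `peak_host_ram_gb` is a hypothetical helper, and `copies_per_process` (one fresh copy, or the fresh copy plus an old one not yet freed) is an assumption; model weights and deserialization overhead are ignored.

```python
def peak_host_ram_gb(checkpoint_gb: float, num_processes: int, copies_per_process: int) -> float:
    """Hypothetical estimate: every process holds `copies_per_process` copies at once."""
    return checkpoint_gb * num_processes * copies_per_process


one_copy = peak_host_ram_gb(1.1, num_processes=8, copies_per_process=1)
two_copies = peak_host_ram_gb(1.1, num_processes=8, copies_per_process=2)
print(f"{one_copy:.1f} GB with one copy per process, {two_copies:.1f} GB with two")
# prints "8.8 GB with one copy per process, 17.6 GB with two"
```

Comparing these figures with the 64 GB observed gives a sense of how much of the spike the loaded state by itself can explain.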

There must be a memory issue somewhere, but anyway, I will try a larger-RAM instance now, maybe 128 GB.

Could it be that the optimizer is loaded into the TPUs’ memory, but when the new weights are loaded, the old weights in the model don’t get freed?

That’s definitely weird. The old weights should be released when the new ones are loaded. What’s probably happening is that both sets of weights are temporarily in RAM in all 8 processes at the same time. It’s annoying that we have to go through the CPU for TPU checkpointing (on GPUs we could just load directly onto the GPU and avoid this issue).
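If the spike comes from all 8 processes holding checkpoint data in RAM at the same time, as suggested above, one possible workaround (not one proposed in this thread) is to stagger the loading so that only one process deserializes the file per round. A framework-agnostic sketch: `staggered_load`, `fake_load` and `worker` are hypothetical, and threads with `threading.Barrier` simulate the 8 processes. In an Accelerate script, `accelerator.process_index`, `accelerator.num_processes` and `accelerator.wait_for_everyone()` could presumably play the roles of `rank`, `world_size` and `barrier`.

```python
import threading


def staggered_load(rank, world_size, barrier, load_fn):
    """Hypothetical helper: ranks take turns running load_fn, one rank per round.

    Every rank calls barrier() once per round, so only the rank whose turn it is
    runs the memory-hungry load while the others wait.
    """
    result = None
    for turn in range(world_size):
        if rank == turn:
            result = load_fn()
        barrier()  # nobody starts the next round until this round's load is done
    return result


# Simulation: 8 threads stand in for the 8 TPU processes.
WORLD_SIZE = 8
barrier = threading.Barrier(WORLD_SIZE)
lock = threading.Lock()
stats = {"active": 0, "max_active": 0}  # concurrent loads right now / at most
results = [None] * WORLD_SIZE


def fake_load(rank):
    """Hypothetical stand-in for torch.load(..., map_location="cpu")."""
    with lock:
        stats["active"] += 1
        stats["max_active"] = max(stats["max_active"], stats["active"])
    # ... the real deserialization would happen here ...
    with lock:
        stats["active"] -= 1
    return f"state for rank {rank}"


def worker(rank):
    results[rank] = staggered_load(rank, WORLD_SIZE, barrier.wait, lambda: fake_load(rank))


threads = [threading.Thread(target=worker, args=(r,)) for r in range(WORLD_SIZE)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print("max concurrent loads:", stats["max_active"])  # prints "max concurrent loads: 1"
```

This only lowers the peak if each process releases its host-side copy (for example, by moving it to the device) before the next rank’s turn; otherwise all copies still end up in RAM, just one after another.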

Thanks very much again sgugger, so here is my update:

  1. Previously, I had the model (AutoModel.from_pretrained()) and tokenizer (AutoTokenizer.from_pretrained()) saved to the local disk and loaded them from there. That setup saves the state of the model, optimizer, and scheduler, but it won’t reload the states from the checkpoint: OOM, SIGKILL, etc…

  2. Now, if I download the model and tokenizer from the internet every time, it reloads from the checkpoints, but when it comes to saving the states, it reports exceptions on all TPUs… I have no idea why; it could be that the optimizers on the TPUs aren’t in sync.

This makes me wonder whether my code has a bug or whether this approach (defining a custom model class) just doesn’t work, like in one of the earlier posts here where the poster also failed to reload the states. I removed the accelerate package and ran on CPU only: it trains, loads, and saves, which confirms the code is fine…

I am really out of ideas. The only thing I can try now is to drop Accelerate and roll back to my earlier PyTorch XLA code. I hadn’t written any checkpointing back then, so it will be interesting to see what happens. Thanks.

Managed to load and save the optimizer and scheduler state with my original PyTorch XLA script!

This confirms there is some issue with Accelerate.

Let me know if you want to follow up, sgugger; I can send you both my XLA and Accelerate scripts! Thanks for your help! :smiley:

Thanks for the info! If you can share gists with each version, I’d happily look and try to find the issue when I have some bandwidth!