How to save model in Colab during TPU training with Accelerate

I’m training a model on a TPU in Colab, and my code is based on this example (with the training loop wrapped in one large function): accelerate/nlp_example.py in the huggingface/accelerate GitHub repository.

How do I save the model trained on the TPU to my Google Drive?

I understand that I need to run code like this in the training function:

accelerator.wait_for_everyone()
unwrapped_model = accelerator.unwrap_model(model)
accelerator.save(unwrapped_model.state_dict(), f'./results/{training_directory}')

But then I get the error below. I know how to save a model in my regular environment, but I’m not sure how to save it from inside the large training function while it runs on the TPU.

Exception in device=TPU:0: [Errno 21] Is a directory: './results/nli-few-shot/TPU/'
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
    fn(gindex, *args)
  File "/usr/local/lib/python3.7/dist-packages/accelerate/utils.py", line 570, in __call__
    self.launcher(*args)
  File "<ipython-input-84-275217ef36aa>", line 97, in training_function
    accelerator.save(unwrapped_model.state_dict(), f'./results/{training_directory}')
  File "/usr/local/lib/python3.7/dist-packages/accelerate/accelerator.py", line 507, in save
    save(obj, f)
  File "/usr/local/lib/python3.7/dist-packages/accelerate/utils.py", line 544, in save
    xm.save(obj, f)
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/core/xla_model.py", line 818, in save
    torch.save(cpu_data, file_or_path)
  File "/usr/local/lib/python3.7/dist-packages/torch/serialization.py", line 376, in save
    with _open_file_like(f, 'wb') as opened_file:
  File "/usr/local/lib/python3.7/dist-packages/torch/serialization.py", line 230, in _open_file_like
    return _open_file(name_or_buffer, mode)
  File "/usr/local/lib/python3.7/dist-packages/torch/serialization.py", line 211, in __init__
    super(_open_file, self).__init__(open(name, mode))
IsADirectoryError: [Errno 21] Is a directory: './results/nli-few-shot/TPU/'
Exception in device=TPU:6: tensorflow/compiler/xla/xla_client/mesh_service.cc:364 : Failed to meet rendezvous 'torch_xla.core.xla_model.save': Socket closed (14)
Traceback (most recent call last):
Exception in device=TPU:5: tensorflow/compiler/xla/xla_client/mesh_service.cc:364 : Failed to meet rendezvous 'torch_xla.core.xla_model.save': Socket closed (14)
Exception in device=TPU:7: tensorflow/compiler/xla/xla_client/mesh_service.cc:364 : Failed to meet rendezvous 'torch_xla.core.xla_model.save': Socket closed (14)
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 329, in _mp_start_fn
    _start_fn(index, pf_cfg, fn, args)
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
    fn(gindex, *args)
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
    fn(gindex, *args)
  File "/usr/local/lib/python3.7/dist-packages/accelerate/utils.py", line 570, in __call__
    self.launcher(*args)
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 323, in _start_fn
    fn(gindex, *args)
  File "/usr/local/lib/python3.7/dist-packages/accelerate/utils.py", line 570, in __call__
    self.launcher(*args)
  File "/usr/local/lib/python3.7/dist-packages/accelerate/utils.py", line 570, in __call__
    self.launcher(*args)
  File "<ipython-input-84-275217ef36aa>", line 97, in training_function
    accelerator.save(unwrapped_model.state_dict(), f'./results/{training_directory}')
  File "<ipython-input-84-275217ef36aa>", line 97, in training_function
    accelerator.save(unwrapped_model.state_dict(), f'./results/{training_directory}')
  File "/usr/local/lib/python3.7/dist-packages/accelerate/accelerator.py", line 507, in save
    save(obj, f)
  File "/usr/local/lib/python3.7/dist-packages/accelerate/accelerator.py", line 507, in save
    save(obj, f)
  File "/usr/local/lib/python3.7/dist-packages/accelerate/utils.py", line 544, in save
    xm.save(obj, f)
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/core/xla_model.py", line 819, in save
    rendezvous('torch_xla.core.xla_model.save')
  File "/usr/local/lib/python3.7/dist-packages/accelerate/utils.py", line 544, in save
    xm.save(obj, f)
  File "<ipython-input-84-275217ef36aa>", line 97, in training_function
    accelerator.save(unwrapped_model.state_dict(), f'./results/{training_directory}')
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/core/xla_model.py", line 819, in save
    rendezvous('torch_xla.core.xla_model.save')
  File "/usr/local/lib/python3.7/dist-packages/accelerate/accelerator.py", line 507, in save
    save(obj, f)
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/core/xla_model.py", line 863, in rendezvous
    return torch_xla._XLAC._xla_rendezvous(get_ordinal(), tag, payload, replicas)
RuntimeError: tensorflow/compiler/xla/xla_client/mesh_service.cc:364 : Failed to meet rendezvous 'torch_xla.core.xla_model.save': Socket closed (14)
  File "/usr/local/lib/python3.7/dist-packages/accelerate/utils.py", line 544, in save
    xm.save(obj, f)
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/core/xla_model.py", line 863, in rendezvous
    return torch_xla._XLAC._xla_rendezvous(get_ordinal(), tag, payload, replicas)
RuntimeError: tensorflow/compiler/xla/xla_client/mesh_service.cc:364 : Failed to meet rendezvous 'torch_xla.core.xla_model.save': Socket closed (14)
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/core/xla_model.py", line 819, in save
    rendezvous('torch_xla.core.xla_model.save')
  File "/usr/local/lib/python3.7/dist-packages/torch_xla/core/xla_model.py", line 863, in rendezvous
    return torch_xla._XLAC._xla_rendezvous(get_ordinal(), tag, payload, replicas)
RuntimeError: tensorflow/compiler/xla/xla_client/mesh_service.cc:364 : Failed to meet rendezvous 'torch_xla.core.xla_model.save': Socket closed (14)
---------------------------------------------------------------------------
ProcessExitedException                    Traceback (most recent call last)
<ipython-input-85-a91f3c0bb4fd> in <module>()
      1 from accelerate import notebook_launcher
      2 
----> 3 notebook_launcher(training_function)

/usr/local/lib/python3.7/dist-packages/torch/multiprocessing/spawn.py in join(self, timeout)
    142                     error_index=error_index,
    143                     error_pid=failed_process.pid,
--> 144                     exit_code=exitcode
    145                 )
    146 

ProcessExitedException: process 0 terminated with exit code 17

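As the error indicates, you are trying to save to a directory, not a file. The "Failed to meet rendezvous" errors from the other TPU cores are most likely a side effect: xm.save makes all processes wait at a rendezvous, which can no longer succeed once process 0 has crashed.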
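The traceback shows torch.save ending in Python's built-in open(name, 'wb'), and opening a directory for writing fails before any tensors are written. A minimal stdlib-only sketch of that root cause, assuming a Linux runtime as in Colab (the temporary directory is a stand-in for ./results/nli-few-shot/TPU/):

```python
import os
import tempfile


def try_open_for_writing(path):
    """Open `path` in binary write mode, as torch.save does internally.

    Returns the name of the exception raised, or None if the open succeeded.
    """
    try:
        with open(path, "wb"):
            return None
    except OSError as exc:  # IsADirectoryError is a subclass of OSError
        return type(exc).__name__


with tempfile.TemporaryDirectory() as results_dir:
    # Passing the directory itself reproduces the error from the traceback.
    dir_result = try_open_for_writing(results_dir)
    # Passing a file path inside that directory works.
    file_result = try_open_for_writing(os.path.join(results_dir, "model_tpu_saved.pth"))

print(dir_result)   # IsADirectoryError (on Linux)
print(file_result)  # None
```

Because the failure happens at open(), it does not matter that the data came from a TPU: any path that names an existing directory fails in the same way.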

Ah OK, I was expecting behaviour similar to trainer.save_model('./results/'), which saves the model (and tokenizer etc.) to a directory. And as a less experienced user, I read the docstring of the .save method as requiring a directory path.

For anyone who runs into the same issue: you need to pass a file path, not a directory. The .pth extension is a common convention rather than a requirement. For example:

training_directory_file = "nli-few-shot/TPU/model_tpu_saved.pth"
accelerator.save(unwrapped_model.state_dict(), f'./results/{training_directory_file}')
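Two practical details when building that file path: torch.save (which accelerator.save reaches through xm.save, per the traceback) does not create missing parent directories, and to land on Google Drive the path has to point inside the mounted Drive folder. A minimal sketch of a helper that creates the folder and returns a file path; the name checkpoint_path and the default file name are hypothetical, not part of Accelerate:

```python
import os


def checkpoint_path(base_dir, training_directory, filename="model_tpu_saved.pth"):
    """Return a file path suitable for accelerator.save, creating its folder.

    Hypothetical helper: base_dir would be './results' for local Colab storage,
    or a folder inside a mounted Google Drive, e.g. '/content/drive/MyDrive/results'.
    """
    folder = os.path.join(base_dir, training_directory)
    # torch.save opens the file directly and will not create missing folders.
    # exist_ok=True keeps this safe when several TPU processes call it at once.
    os.makedirs(folder, exist_ok=True)
    return os.path.join(folder, filename)
```

Inside training_function this would be used as accelerator.save(unwrapped_model.state_dict(), checkpoint_path('./results', training_directory)). To write to Drive, a reasonable approach is to mount it in a separate cell before calling notebook_launcher (from google.colab import drive; drive.mount('/content/drive')) and pass a base directory under /content/drive/MyDrive. Judging by the rendezvous call in the traceback, the accelerator.save call should stay on every process rather than being guarded by accelerator.is_main_process, since xm.save waits for all TPU processes before returning.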