Save accelerate model

Hi everyone!

I’ve been working with the Hugging Face Trainer and encountered some challenges related to saving and loading models when using Fully Sharded Data Parallel (FSDP). Here’s the context:

My Setup:

  • I’m using HF Trainer without any explicit accelerate-related code, as I understand HF Trainer utilizes accelerate automatically under the hood.
  • I run my script with the following command:
accelerate launch --config_file CONFIG_FILE_PATH my_script.py
  • Depending on the configuration file, I’m toggling between DeepSpeed and FSDP.

Observations:

  1. DeepSpeed Configuration
    Using the DeepSpeed config, everything works smoothly. The saved model can be reloaded later via:
AutoModelForCausalLM.from_pretrained(saved_model_dir)
  2. FSDP Configuration
    When I switch to the FSDP config, the saved directory contains:
rng_state_0.pth  
rng_state_1.pth  
scheduler.pt  
trainer_state.json  

However, it does not produce the expected model files compatible with AutoModelForCausalLM.from_pretrained.
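For reference, one quick way to tell the two outcomes apart is to look for the files that from_pretrained reads: a config.json plus a weights file (model.safetensors, pytorch_model.bin, or their .index.json variants for sharded saves). The helper below is a hypothetical sketch (describe_checkpoint is not part of any library). It also assumes that accelerate writes FSDP sharded weights into sub-directories named like pytorch_model_fsdp_0/ containing *.distcp files, which is worth confirming against your own output directory.

```python
from pathlib import Path

# Weight files that from_pretrained can load (single-file or sharded-index form)
HF_WEIGHT_FILES = (
    "model.safetensors",
    "model.safetensors.index.json",
    "pytorch_model.bin",
    "pytorch_model.bin.index.json",
)


def describe_checkpoint(path):
    """Return a short label describing what kind of checkpoint `path` holds."""
    root = Path(path)
    has_config = (root / "config.json").is_file()
    has_hf_weights = any((root / name).is_file() for name in HF_WEIGHT_FILES)
    if has_hf_weights:
        if has_config:
            return "from_pretrained-compatible"
        return "weights found, but config.json is missing"
    # Assumption: accelerate stores FSDP SHARDED_STATE_DICT weights in
    # sub-directories such as pytorch_model_fsdp_0/ holding *.distcp files
    sharded_dirs = [
        d for d in root.glob("pytorch_model_fsdp_*")
        if d.is_dir() and any(d.glob("*.distcp"))
    ]
    if sharded_dirs:
        return "fsdp-sharded"
    return "no model weights found"
```

A directory containing only rng_state_*.pth, scheduler.pt and trainer_state.json would be reported as "no model weights found".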


Questions:

  1. Do I need to explicitly call accelerator.unwrap_model to save the model correctly with FSDP?
  • If so, why does it work automatically with the DeepSpeed config without needing this step?
  2. How can I retrieve the properly wrapped model from the HF Trainer?
  • Is the model being wrapped behind the scenes, and how can I access it in this case?
  3. Could someone point me to documentation or an example showcasing how to use HF Trainer with accelerate and FSDP in a way that ensures the model is saved and reloadable correctly?
  4. Why do I see warnings like the following in the console during training?
/usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
  warnings.warn(
/usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/_state_dict_utils.py:732: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
  local_shape = tensor.shape
/usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/_state_dict_utils.py:744: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
  tensor.shape,
/usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/_state_dict_utils.py:746: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
  tensor.dtype,
/usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/_state_dict_utils.py:747: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
  tensor.device,
/usr/local/lib/python3.11/dist-packages/accelerate/utils/fsdp_utils.py:108: FutureWarning: `save_state_dict` is deprecated and will be removed in future versions.Please use `save` instead.
  dist_cp.save_state_dict(
/usr/local/lib/python3.11/dist-packages/accelerate/utils/fsdp_utils.py:200: FutureWarning: `save_state_dict` is deprecated and will be removed in future versions.Please use `save` instead.
  dist_cp.save_state_dict(
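On question 4: judging from their text, these are FutureWarning deprecation notices rather than errors. PyTorch is steering users away from FSDP.state_dict_type() and ShardedTensor toward get_state_dict()/set_state_dict() and DTensor, and accelerate 1.2.1 still calls the deprecated dist_cp.save_state_dict. If the noise is a problem, one option is a narrowly scoped filter. The sketch below (the helper name is hypothetical) ignores only FutureWarnings whose text starts with the prefixes seen in the log above, so unrelated warnings keep showing.

```python
import re
import warnings

# Message prefixes copied from the FutureWarnings in the training log
DEPRECATION_PREFIXES = (
    "FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated",
    "Please use DTensor instead and we are deprecating ShardedTensor",
    "`save_state_dict` is deprecated",
)


def silence_fsdp_deprecation_warnings():
    """Ignore only these specific FutureWarnings; everything else still shows."""
    for prefix in DEPRECATION_PREFIXES:
        # `message` is a regex matched against the start of the warning text,
        # so the literal prefix is escaped first
        warnings.filterwarnings(
            "ignore", message=re.escape(prefix), category=FutureWarning
        )
```

Call it near the top of the training script, before trainer.train(). Since accelerate launch runs the script once per process, every rank installs the filter.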


Environment

Python 3.11.10

pip list
Package                           Version
--------------------------------- --------------
accelerate                        1.2.1
aiofiles                          23.2.1
aiohappyeyeballs                  2.4.4
aiohttp                           3.11.10
aiosignal                         1.3.2
annotated-types                   0.7.0
anyio                             4.6.2.post1
argon2-cffi                       23.1.0
argon2-cffi-bindings              21.2.0
arrow                             1.3.0
asttokens                         2.4.1
async-lru                         2.0.4
asyncio                           3.4.3
attrs                             24.2.0
babel                             2.16.0
beautifulsoup4                    4.12.3
bleach                            6.2.0
blinker                           1.4
certifi                           2024.8.30
cffi                              1.17.1
charset-normalizer                3.4.0
click                             8.1.7
comm                              0.2.2
cryptography                      3.4.8
datasets                          3.2.0
dbus-python                       1.2.18
debugpy                           1.8.8
decorator                         5.1.1
deepspeed                         0.16.1
defusedxml                        0.7.1
dill                              0.3.8
distro                            1.7.0
docker-pycreds                    0.4.0
einops                            0.8.0
entrypoints                       0.4
executing                         2.1.0
ezdxf                             1.3.5
fastapi                           0.115.6
fastjsonschema                    2.20.0
ffmpy                             0.4.0
filelock                          3.13.1
fire                              0.7.0
fonttools                         4.55.3
fqdn                              1.5.1
frozenlist                        1.5.0
fsspec                            2024.2.0
gitdb                             4.0.11
GitPython                         3.1.43
gradio                            5.9.1
gradio_client                     1.5.2
h11                               0.14.0
hjson                             3.1.0
httpcore                          1.0.6
httplib2                          0.20.2
httpx                             0.27.2
huggingface-hub                   0.27.0
idna                              3.10
importlib-metadata                4.6.4
ipykernel                         6.29.5
ipython                           8.29.0
ipython-genutils                  0.2.0
ipywidgets                        8.1.5
isoduration                       20.11.0
jedi                              0.19.2
jeepney                           0.7.1
Jinja2                            3.1.3
joblib                            1.4.2
json5                             0.9.28
jsonpointer                       3.0.0
jsonschema                        4.23.0
jsonschema-specifications         2024.10.1
jupyter-archive                   3.4.0
jupyter_client                    7.4.9
jupyter_contrib_core              0.4.2
jupyter_contrib_nbextensions      0.7.0
jupyter_core                      5.7.2
jupyter-events                    0.10.0
jupyter-highlight-selected-word   0.2.0
jupyter-lsp                       2.2.5
jupyter_nbextensions_configurator 0.6.4
jupyter_server                    2.14.2
jupyter_server_terminals          0.5.3
jupyterlab                        4.2.5
jupyterlab_pygments               0.3.0
jupyterlab_server                 2.27.3
jupyterlab_widgets                3.0.13
keyring                           23.5.0
launchpadlib                      1.10.16
lazr.restfulclient                0.14.4
lazr.uri                          1.0.6
lxml                              5.3.0
markdown-it-py                    3.0.0
MarkupSafe                        2.1.5
matplotlib-inline                 0.1.7
mdurl                             0.1.2
mistune                           3.0.2
more-itertools                    8.10.0
mpmath                            1.3.0
msgpack                           1.1.0
multidict                         6.1.0
multiprocess                      0.70.16
nbclassic                         1.1.0
nbclient                          0.10.0
nbconvert                         7.16.4
nbformat                          5.10.4
nest-asyncio                      1.6.0
networkx                          3.2.1
ninja                             1.11.1.3
notebook                          6.5.5
notebook_shim                     0.2.4
numpy                             1.26.3
nvidia-cublas-cu12                12.4.5.8
nvidia-cuda-cupti-cu12            12.4.127
nvidia-cuda-nvrtc-cu12            12.4.127
nvidia-cuda-runtime-cu12          12.4.127
nvidia-cudnn-cu12                 9.1.0.70
nvidia-cufft-cu12                 11.2.1.3
nvidia-curand-cu12                10.3.5.147
nvidia-cusolver-cu12              11.6.1.9
nvidia-cusparse-cu12              12.3.1.170
nvidia-ml-py                      12.560.30
nvidia-nccl-cu12                  2.21.5
nvidia-nvjitlink-cu12             12.4.127
nvidia-nvtx-cu12                  12.4.127
oauthlib                          3.2.0
orjson                            3.10.12
overrides                         7.7.0
packaging                         24.2
pandas                            2.2.3
pandocfilters                     1.5.1
parso                             0.8.4
peft                              0.14.0
pexpect                           4.9.0
pillow                            10.2.0
pip                               24.3.1
platformdirs                      4.3.6
prometheus_client                 0.21.0
prompt_toolkit                    3.0.48
propcache                         0.2.1
protobuf                          5.29.1
psutil                            6.1.0
ptyprocess                        0.7.0
pure_eval                         0.2.3
py-cpuinfo                        9.0.0
pyarrow                           18.1.0
pycparser                         2.22
pydantic                          2.10.3
pydantic_core                     2.27.1
pydub                             0.25.1
Pygments                          2.18.0
PyGObject                         3.42.1
PyJWT                             2.3.0
pyparsing                         2.4.7
python-apt                        2.4.0+ubuntu4
python-dateutil                   2.9.0.post0
python-json-logger                2.0.7
python-multipart                  0.0.20
pytz                              2024.2
PyYAML                            6.0.2
pyzmq                             24.0.1
referencing                       0.35.1
regex                             2024.11.6
requests                          2.32.3
rfc3339-validator                 0.1.4
rfc3986-validator                 0.1.1
rich                              13.9.4
rpds-py                           0.21.0
ruff                              0.8.3
safehttpx                         0.1.6
safetensors                       0.4.5
scikit-learn                      1.6.0
scipy                             1.14.1
SecretStorage                     3.3.1
semantic-version                  2.10.0
Send2Trash                        1.8.3
sentry-sdk                        2.19.2
setproctitle                      1.3.4
setuptools                        75.4.0
shellingham                       1.5.4
six                               1.16.0
smmap                             5.0.1
sniffio                           1.3.1
soupsieve                         2.6
stack-data                        0.6.3
starlette                         0.41.3
svgwrite                          1.4.3
sympy                             1.13.1
termcolor                         2.5.0
terminado                         0.18.1
threadpoolctl                     3.5.0
tinycss2                          1.4.0
tokenizers                        0.21.0
tomlkit                           0.13.2
torch                             2.5.1+cu124
torchaudio                        2.5.1+cu124
torchvision                       0.20.1+cu124
tornado                           6.4.1
tqdm                              4.67.1
traitlets                         5.14.3
transformers                      4.47.1
triton                            3.1.0
typer                             0.15.1
types-python-dateutil             2.9.0.20241003
typing_extensions                 4.12.2
tzdata                            2024.2
uri-template                      1.3.0
urllib3                           2.2.3
uvicorn                           0.34.0
wadllib                           1.3.6
wandb                             0.19.1
wcwidth                           0.2.13
webcolors                         24.11.1
webencodings                      0.5.1
websocket-client                  1.8.0
websockets                        14.1
wheel                             0.45.0
widgetsnbextension                4.0.13
xxhash                            3.5.0
yarl                              1.18.3
zipp                              1.0.0
nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Sep_12_02:18:05_PDT_2024
Cuda compilation tools, release 12.6, V12.6.77
Build cuda_12.6.r12.6/compiler.34841621_0
nvidia-smi
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05             Driver Version: 550.127.05     CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A40                     On  |   00000000:52:00.0 Off |                    0 |
|  0%   72C    P0            266W /  300W |   24277MiB /  46068MiB |     99%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A40                     On  |   00000000:57:00.0 Off |                    0 |
|  0%   73C    P0            253W /  300W |   41029MiB /  46068MiB |     96%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+


Any guidance or suggestions would be greatly appreciated. Thanks in advance! 🤗


UP 🚀🚀🚀🚀🚀


I’ve run into the same issue. Have you found a workaround? I’m using Accelerate with SFT; training completes, but saving is proving difficult.


This issue?

I’m not an expert here, but IIUC, this is expected behaviour:

Reason for existence of SHARDED checkpoints
Original issue: saving a large model with FSDP takes time, because it requires all model parameters to be gathered into a single process and moved to CPU. This is extremely time-consuming and not the best use of GPU time during training.

Solution: SHARDED_STATE_DICT doesn’t gather all the tensors in one process; instead, each GPU saves only its own part of the model.

This makes things faster (good for checkpointing), but harder to load. In theory, the Trainer should be able to load these checkpoints without extra setup (though I haven’t tried it myself).
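To make the difference concrete, here is a toy, plain-Python illustration (lists stand in for tensors, and both function names are made up for this example). It is not how FSDP actually lays out parameters, which involves flattened parameters and torch.distributed collectives, but it shows the data movement: with a sharded layout each rank keeps and saves only its own slice, while a full state dict requires gathering every slice into a single process.

```python
def shard_state_dict(full, world_size):
    """Toy SHARDED layout: each rank keeps one contiguous slice of every tensor."""
    shards = [{} for _ in range(world_size)]
    for name, values in full.items():
        chunk = -(-len(values) // world_size)  # ceiling division
        for rank in range(world_size):
            shards[rank][name] = values[rank * chunk:(rank + 1) * chunk]
    return shards


def gather_full_state_dict(shards):
    """Toy FULL layout: concatenate every rank's slices inside one process."""
    full = {}
    for shard in shards:  # shards must be visited in rank order
        for name, piece in shard.items():
            full.setdefault(name, []).extend(piece)
    return full


full = {"w": [1, 2, 3, 4, 5]}
shards = shard_state_dict(full, world_size=2)  # rank 0 holds [1, 2, 3], rank 1 holds [4, 5]
assert gather_full_state_dict(shards) == full
```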

Usually, though, you would want to save the model with FULL_STATE_DICT at least once at the end of training, so that the final model is in an easily usable format.

For this, I’ve seen people change the format after training and before the final model save:

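...
trainer.train(...)
# Switch from SHARDED_STATE_DICT to FULL_STATE_DICT so that save_model()
# gathers the full weights and writes from_pretrained-compatible files
if trainer.is_fsdp_enabled:
    trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
trainer.save_model(...)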

btw, if you have the same number of GPUs as you originally trained with, you should be able to load a model saved with SHARDED_STATE_DICT like so:

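import torch.distributed.checkpoint as dist_cp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.api import StateDictType

# Assume you have a model wrapped in FSDP, with the process group already
# initialized (e.g. the script was started via `accelerate launch`)
model = FSDP(...)

# A SHARDED_STATE_DICT checkpoint is a directory of *.distcp files written by
# torch.distributed.checkpoint, not a single .pt file; accelerate names it
# pytorch_model_fsdp_0 inside the checkpoint folder
ckpt_dir = "path/to/pytorch_model_fsdp_0"

# Load the sharded state dict: build a target with the model's own sharded
# layout, let dist_cp fill it in place, then hand it back to the model
with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    state_dict = {"model": model.state_dict()}  # accelerate saves under "model"
    dist_cp.load(state_dict=state_dict, checkpoint_id=ckpt_dir)
    model.load_state_dict(state_dict["model"])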