Save accelerate model

Hi everyone!

I’ve been working with the Hugging Face Trainer and encountered some challenges related to saving and loading models when using Fully Sharded Data Parallel (FSDP). Here’s the context:

My Setup:

  • I’m using HF Trainer without any explicit accelerate-related code, as I understand HF Trainer utilizes accelerate automatically under the hood.
  • I run my script with the following command:
accelerate launch --config_file CONFIG_FILE_PATH my_script.py
  • Depending on the configuration file, I’m toggling between DeepSpeed and FSDP.

Observations:

  1. DeepSpeed Configuration
    Using the DeepSpeed config, everything works smoothly. The saved model can be reloaded later via:
AutoModelForCausalLM.from_pretrained(saved_model_dir)
  2. FSDP Configuration
    When I switch to the FSDP config, the saved directory contains:
rng_state_0.pth  
rng_state_1.pth  
scheduler.pt  
trainer_state.json  

However, it does not produce the expected model files compatible with AutoModelForCausalLM.from_pretrained.
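To make "expected model files" concrete, here is a minimal sketch of a directory check. It assumes that from_pretrained needs a config.json plus one of the usual weight files (single-file or sharded, safetensors or .bin). The helper name looks_loadable is hypothetical and not part of any library.

```python
from pathlib import Path

# Weight file names that from_pretrained normally looks for
# (assumption: standard, non-custom checkpoint layouts only).
WEIGHT_FILES = (
    "model.safetensors",
    "model.safetensors.index.json",  # sharded safetensors checkpoint
    "pytorch_model.bin",
    "pytorch_model.bin.index.json",  # sharded .bin checkpoint
)


def looks_loadable(saved_model_dir):
    """Return True if the directory has a config and at least one weight file."""
    directory = Path(saved_model_dir)
    has_config = (directory / "config.json").is_file()
    has_weights = any((directory / name).is_file() for name in WEIGHT_FILES)
    return has_config and has_weights
```

With only the four files listed above (rng_state_*.pth, scheduler.pt, trainer_state.json), this check returns False, which matches what I see with the FSDP config.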


Questions:

  1. Do I need to explicitly call accelerator.unwrap_model to save the model correctly with FSDP?
  • If so, why does it work automatically with the DeepSpeed config without needing this step?
  2. How can I retrieve the properly wrapped model from the HF Trainer?
  • Is the model being wrapped behind the scenes, and how can I access it in this case?
  3. Could someone point me to documentation or an example showcasing how to use HF Trainer with accelerate and FSDP in a way that ensures the model is saved and reloadable correctly?
  4. Why are there warnings in the console like the following during training?
/usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
  warnings.warn(
/usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/_state_dict_utils.py:732: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
  local_shape = tensor.shape
/usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/_state_dict_utils.py:744: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
  tensor.shape,
/usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/_state_dict_utils.py:746: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
  tensor.dtype,
/usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/_state_dict_utils.py:747: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
  tensor.device,
/usr/local/lib/python3.11/dist-packages/accelerate/utils/fsdp_utils.py:108: FutureWarning: `save_state_dict` is deprecated and will be removed in future versions.Please use `save` instead.
  dist_cp.save_state_dict(
/usr/local/lib/python3.11/dist-packages/accelerate/utils/fsdp_utils.py:200: FutureWarning: `save_state_dict` is deprecated and will be removed in future versions.Please use `save` instead.
  dist_cp.save_state_dict(
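Regarding the save questions (1 to 3) above, one pattern that may apply is switching the FSDP plugin to a full state dict right before the final save. This is only a sketch under assumptions: it assumes the config uses a sharded state dict type (which I have not confirmed is the cause here), and the wrapper name save_reloadable_model is hypothetical. The attributes it touches (is_fsdp_enabled, accelerator.state.fsdp_plugin.set_state_dict_type, save_model) are the ones exposed by Trainer and accelerate.

```python
def save_reloadable_model(trainer, output_dir):
    """Save a model that from_pretrained can load, also when FSDP is active.

    Assumption: with a sharded FSDP state dict, Trainer.save_model does not
    write consolidated weights, so we request a full state dict first.
    """
    if getattr(trainer, "is_fsdp_enabled", False):
        # Ask the FSDP plugin to gather the full (unsharded) weights.
        trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
    # Must be called on every rank; Trainer decides which rank writes files.
    trainer.save_model(output_dir)
```

It would be called once after trainer.train(), for example save_reloadable_model(trainer, saved_model_dir), in the script started with accelerate launch.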


Environment

Python 3.11.10

pip list
Package                           Version
--------------------------------- --------------
accelerate                        1.2.1
aiofiles                          23.2.1
aiohappyeyeballs                  2.4.4
aiohttp                           3.11.10
aiosignal                         1.3.2
annotated-types                   0.7.0
anyio                             4.6.2.post1
argon2-cffi                       23.1.0
argon2-cffi-bindings              21.2.0
arrow                             1.3.0
asttokens                         2.4.1
async-lru                         2.0.4
asyncio                           3.4.3
attrs                             24.2.0
babel                             2.16.0
beautifulsoup4                    4.12.3
bleach                            6.2.0
blinker                           1.4
certifi                           2024.8.30
cffi                              1.17.1
charset-normalizer                3.4.0
click                             8.1.7
comm                              0.2.2
cryptography                      3.4.8
datasets                          3.2.0
dbus-python                       1.2.18
debugpy                           1.8.8
decorator                         5.1.1
deepspeed                         0.16.1
defusedxml                        0.7.1
dill                              0.3.8
distro                            1.7.0
docker-pycreds                    0.4.0
einops                            0.8.0
entrypoints                       0.4
executing                         2.1.0
ezdxf                             1.3.5
fastapi                           0.115.6
fastjsonschema                    2.20.0
ffmpy                             0.4.0
filelock                          3.13.1
fire                              0.7.0
fonttools                         4.55.3
fqdn                              1.5.1
frozenlist                        1.5.0
fsspec                            2024.2.0
gitdb                             4.0.11
GitPython                         3.1.43
gradio                            5.9.1
gradio_client                     1.5.2
h11                               0.14.0
hjson                             3.1.0
httpcore                          1.0.6
httplib2                          0.20.2
httpx                             0.27.2
huggingface-hub                   0.27.0
idna                              3.10
importlib-metadata                4.6.4
ipykernel                         6.29.5
ipython                           8.29.0
ipython-genutils                  0.2.0
ipywidgets                        8.1.5
isoduration                       20.11.0
jedi                              0.19.2
jeepney                           0.7.1
Jinja2                            3.1.3
joblib                            1.4.2
json5                             0.9.28
jsonpointer                       3.0.0
jsonschema                        4.23.0
jsonschema-specifications         2024.10.1
jupyter-archive                   3.4.0
jupyter_client                    7.4.9
jupyter_contrib_core              0.4.2
jupyter_contrib_nbextensions      0.7.0
jupyter_core                      5.7.2
jupyter-events                    0.10.0
jupyter-highlight-selected-word   0.2.0
jupyter-lsp                       2.2.5
jupyter_nbextensions_configurator 0.6.4
jupyter_server                    2.14.2
jupyter_server_terminals          0.5.3
jupyterlab                        4.2.5
jupyterlab_pygments               0.3.0
jupyterlab_server                 2.27.3
jupyterlab_widgets                3.0.13
keyring                           23.5.0
launchpadlib                      1.10.16
lazr.restfulclient                0.14.4
lazr.uri                          1.0.6
lxml                              5.3.0
markdown-it-py                    3.0.0
MarkupSafe                        2.1.5
matplotlib-inline                 0.1.7
mdurl                             0.1.2
mistune                           3.0.2
more-itertools                    8.10.0
mpmath                            1.3.0
msgpack                           1.1.0
multidict                         6.1.0
multiprocess                      0.70.16
nbclassic                         1.1.0
nbclient                          0.10.0
nbconvert                         7.16.4
nbformat                          5.10.4
nest-asyncio                      1.6.0
networkx                          3.2.1
ninja                             1.11.1.3
notebook                          6.5.5
notebook_shim                     0.2.4
numpy                             1.26.3
nvidia-cublas-cu12                12.4.5.8
nvidia-cuda-cupti-cu12            12.4.127
nvidia-cuda-nvrtc-cu12            12.4.127
nvidia-cuda-runtime-cu12          12.4.127
nvidia-cudnn-cu12                 9.1.0.70
nvidia-cufft-cu12                 11.2.1.3
nvidia-curand-cu12                10.3.5.147
nvidia-cusolver-cu12              11.6.1.9
nvidia-cusparse-cu12              12.3.1.170
nvidia-ml-py                      12.560.30
nvidia-nccl-cu12                  2.21.5
nvidia-nvjitlink-cu12             12.4.127
nvidia-nvtx-cu12                  12.4.127
oauthlib                          3.2.0
orjson                            3.10.12
overrides                         7.7.0
packaging                         24.2
pandas                            2.2.3
pandocfilters                     1.5.1
parso                             0.8.4
peft                              0.14.0
pexpect                           4.9.0
pillow                            10.2.0
pip                               24.3.1
platformdirs                      4.3.6
prometheus_client                 0.21.0
prompt_toolkit                    3.0.48
propcache                         0.2.1
protobuf                          5.29.1
psutil                            6.1.0
ptyprocess                        0.7.0
pure_eval                         0.2.3
py-cpuinfo                        9.0.0
pyarrow                           18.1.0
pycparser                         2.22
pydantic                          2.10.3
pydantic_core                     2.27.1
pydub                             0.25.1
Pygments                          2.18.0
PyGObject                         3.42.1
PyJWT                             2.3.0
pyparsing                         2.4.7
python-apt                        2.4.0+ubuntu4
python-dateutil                   2.9.0.post0
python-json-logger                2.0.7
python-multipart                  0.0.20
pytz                              2024.2
PyYAML                            6.0.2
pyzmq                             24.0.1
referencing                       0.35.1
regex                             2024.11.6
requests                          2.32.3
rfc3339-validator                 0.1.4
rfc3986-validator                 0.1.1
rich                              13.9.4
rpds-py                           0.21.0
ruff                              0.8.3
safehttpx                         0.1.6
safetensors                       0.4.5
scikit-learn                      1.6.0
scipy                             1.14.1
SecretStorage                     3.3.1
semantic-version                  2.10.0
Send2Trash                        1.8.3
sentry-sdk                        2.19.2
setproctitle                      1.3.4
setuptools                        75.4.0
shellingham                       1.5.4
six                               1.16.0
smmap                             5.0.1
sniffio                           1.3.1
soupsieve                         2.6
stack-data                        0.6.3
starlette                         0.41.3
svgwrite                          1.4.3
sympy                             1.13.1
termcolor                         2.5.0
terminado                         0.18.1
threadpoolctl                     3.5.0
tinycss2                          1.4.0
tokenizers                        0.21.0
tomlkit                           0.13.2
torch                             2.5.1+cu124
torchaudio                        2.5.1+cu124
torchvision                       0.20.1+cu124
tornado                           6.4.1
tqdm                              4.67.1
traitlets                         5.14.3
transformers                      4.47.1
triton                            3.1.0
typer                             0.15.1
types-python-dateutil             2.9.0.20241003
typing_extensions                 4.12.2
tzdata                            2024.2
uri-template                      1.3.0
urllib3                           2.2.3
uvicorn                           0.34.0
wadllib                           1.3.6
wandb                             0.19.1
wcwidth                           0.2.13
webcolors                         24.11.1
webencodings                      0.5.1
websocket-client                  1.8.0
websockets                        14.1
wheel                             0.45.0
widgetsnbextension                4.0.13
xxhash                            3.5.0
yarl                              1.18.3
zipp                              1.0.0

nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2024 NVIDIA Corporation
Built on Thu_Sep_12_02:18:05_PDT_2024
Cuda compilation tools, release 12.6, V12.6.77
Build cuda_12.6.r12.6/compiler.34841621_0

nvidia-smi
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05             Driver Version: 550.127.05     CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A40                     On  |   00000000:52:00.0 Off |                    0 |
|  0%   72C    P0            266W /  300W |   24277MiB /  46068MiB |     99%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A40                     On  |   00000000:57:00.0 Off |                    0 |
|  0%   73C    P0            253W /  300W |   41029MiB /  46068MiB |     96%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+


Any guidance or suggestions would be greatly appreciated. Thanks in advance! :hugs:


UP :rocket: :rocket: :rocket: :rocket: :rocket: