Hi everyone!
I’ve been working with the Hugging Face Trainer and encountered some challenges related to saving and loading models when using Fully Sharded Data Parallel (FSDP). Here’s the context:
My Setup:
- I’m using HF Trainer without any explicit `accelerate`-related code, since HF Trainer uses `accelerate` automatically under the hood.
- I run my script with the following command:
accelerate launch --config_file CONFIG_FILE_PATH my_script.py
- Depending on the configuration file, I’m toggling between DeepSpeed and FSDP.
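For context, the FSDP config I toggle to looks roughly like this (a sketch, not my exact file — the values are illustrative, and I’m unsure whether `fsdp_state_dict_type` is the setting that matters here):

```yaml
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
num_processes: 2
mixed_precision: bf16
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_sharding_strategy: FULL_SHARD
  # From what I've read, this key controls whether full (unsharded)
  # weights are gathered at save time -- is this the relevant setting?
  fsdp_state_dict_type: FULL_STATE_DICT
```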
Observations:
- DeepSpeed Configuration
Using the DeepSpeed config, everything works smoothly. The saved model can be reloaded later via:
AutoModelForCausalLM.from_pretrained(saved_model_dir)
- FSDP Configuration
When I switch to the FSDP config, the saved directory contains:
rng_state_0.pth
rng_state_1.pth
scheduler.pt
trainer_state.json
However, it does not produce the model files expected by `AutoModelForCausalLM.from_pretrained`.
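As a workaround I’ve been experimenting with something along these lines after `trainer.train()` (a sketch, not verified in every setup; `save_full_model` is my own helper name, and I’m assuming `trainer.accelerator` is the `Accelerator` the Trainer creates internally):

```python
def save_full_model(trainer, save_dir: str) -> None:
    """Gather sharded FSDP weights and write a from_pretrained-compatible dir."""
    accelerator = trainer.accelerator
    # get_state_dict() should consolidate the sharded parameters into a
    # full state dict on the main process, regardless of the wrapping plugin
    state_dict = accelerator.get_state_dict(trainer.model)
    # unwrap_model() strips the FSDP wrapper so save_pretrained() is available
    model = accelerator.unwrap_model(trainer.model)
    if accelerator.is_main_process:
        model.save_pretrained(save_dir, state_dict=state_dict)
```

Is this the intended pattern, or is there a supported Trainer/config option that makes it unnecessary?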
Questions:
- Do I need to explicitly call `accelerator.unwrap_model` to save the model correctly with FSDP?
- If so, why does it work automatically with the DeepSpeed config without needing this step?
- How can I retrieve the properly wrapped model from the HF Trainer? Is the model being wrapped behind the scenes, and how can I access it in this case?
- Could someone point me to documentation or an example showing how to use HF Trainer with `accelerate` and FSDP in a way that ensures the model is saved and reloadable correctly?
- Why are there warnings like the following in the console while training?
/usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
warnings.warn(
/usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/_state_dict_utils.py:732: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
local_shape = tensor.shape
/usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/_state_dict_utils.py:744: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.shape,
/usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/_state_dict_utils.py:746: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.dtype,
/usr/local/lib/python3.11/dist-packages/torch/distributed/fsdp/_state_dict_utils.py:747: FutureWarning: Please use DTensor instead and we are deprecating ShardedTensor.
tensor.device,
/usr/local/lib/python3.11/dist-packages/accelerate/utils/fsdp_utils.py:108: FutureWarning: `save_state_dict` is deprecated and will be removed in future versions.Please use `save` instead.
dist_cp.save_state_dict(
/usr/local/lib/python3.11/dist-packages/accelerate/utils/fsdp_utils.py:200: FutureWarning: `save_state_dict` is deprecated and will be removed in future versions.Please use `save` instead.
dist_cp.save_state_dict(
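Related to the question about retrieving the wrapped model: from skimming the Trainer source I believe it keeps both the original and the wrapped module (my understanding, to be confirmed; `inspect_wrapping` is just my own helper for poking at this):

```python
def inspect_wrapping(trainer) -> bool:
    """Return True if the Trainer holds a wrapped copy distinct from the original.

    My understanding (to be confirmed): `trainer.model` is the model as passed
    in, while `trainer.model_wrapped` is the FSDP/DeepSpeed/DDP-wrapped module;
    before any wrapping happens, the two attributes are the same object.
    """
    return trainer.model_wrapped is not trainer.model
```

Is `trainer.model_wrapped` the right attribute to reach for here, or is it considered internal?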
Environment
Python 3.11.10
pip list
Package Version --------------------------------- -------------- accelerate 1.2.1 aiofiles 23.2.1 aiohappyeyeballs 2.4.4 aiohttp 3.11.10 aiosignal 1.3.2 annotated-types 0.7.0 anyio 4.6.2.post1 argon2-cffi 23.1.0 argon2-cffi-bindings 21.2.0 arrow 1.3.0 asttokens 2.4.1 async-lru 2.0.4 asyncio 3.4.3 attrs 24.2.0 babel 2.16.0 beautifulsoup4 4.12.3 bleach 6.2.0 blinker 1.4 certifi 2024.8.30 cffi 1.17.1 charset-normalizer 3.4.0 click 8.1.7 comm 0.2.2 cryptography 3.4.8 datasets 3.2.0 dbus-python 1.2.18 debugpy 1.8.8 decorator 5.1.1 deepspeed 0.16.1 defusedxml 0.7.1 dill 0.3.8 distro 1.7.0 docker-pycreds 0.4.0 einops 0.8.0 entrypoints 0.4 executing 2.1.0 ezdxf 1.3.5 fastapi 0.115.6 fastjsonschema 2.20.0 ffmpy 0.4.0 filelock 3.13.1 fire 0.7.0 fonttools 4.55.3 fqdn 1.5.1 frozenlist 1.5.0 fsspec 2024.2.0 gitdb 4.0.11 GitPython 3.1.43 gradio 5.9.1 gradio_client 1.5.2 h11 0.14.0 hjson 3.1.0 httpcore 1.0.6 httplib2 0.20.2 httpx 0.27.2 huggingface-hub 0.27.0 idna 3.10 importlib-metadata 4.6.4 ipykernel 6.29.5 ipython 8.29.0 ipython-genutils 0.2.0 ipywidgets 8.1.5 isoduration 20.11.0 jedi 0.19.2 jeepney 0.7.1 Jinja2 3.1.3 joblib 1.4.2 json5 0.9.28 jsonpointer 3.0.0 jsonschema 4.23.0 jsonschema-specifications 2024.10.1 jupyter-archive 3.4.0 jupyter_client 7.4.9 jupyter_contrib_core 0.4.2 jupyter_contrib_nbextensions 0.7.0 jupyter_core 5.7.2 jupyter-events 0.10.0 jupyter-highlight-selected-word 0.2.0 jupyter-lsp 2.2.5 jupyter_nbextensions_configurator 0.6.4 jupyter_server 2.14.2 jupyter_server_terminals 0.5.3 jupyterlab 4.2.5 jupyterlab_pygments 0.3.0 jupyterlab_server 2.27.3 jupyterlab_widgets 3.0.13 keyring 23.5.0 launchpadlib 1.10.16 lazr.restfulclient 0.14.4 lazr.uri 1.0.6 lxml 5.3.0 markdown-it-py 3.0.0 MarkupSafe 2.1.5 matplotlib-inline 0.1.7 mdurl 0.1.2 mistune 3.0.2 more-itertools 8.10.0 mpmath 1.3.0 msgpack 1.1.0 multidict 6.1.0 multiprocess 0.70.16 nbclassic 1.1.0 nbclient 0.10.0 nbconvert 7.16.4 nbformat 5.10.4 nest-asyncio 1.6.0 networkx 3.2.1 ninja 1.11.1.3 notebook 
6.5.5 notebook_shim 0.2.4 numpy 1.26.3 nvidia-cublas-cu12 12.4.5.8 nvidia-cuda-cupti-cu12 12.4.127 nvidia-cuda-nvrtc-cu12 12.4.127 nvidia-cuda-runtime-cu12 12.4.127 nvidia-cudnn-cu12 9.1.0.70 nvidia-cufft-cu12 11.2.1.3 nvidia-curand-cu12 10.3.5.147 nvidia-cusolver-cu12 11.6.1.9 nvidia-cusparse-cu12 12.3.1.170 nvidia-ml-py 12.560.30 nvidia-nccl-cu12 2.21.5 nvidia-nvjitlink-cu12 12.4.127 nvidia-nvtx-cu12 12.4.127 oauthlib 3.2.0 orjson 3.10.12 overrides 7.7.0 packaging 24.2 pandas 2.2.3 pandocfilters 1.5.1 parso 0.8.4 peft 0.14.0 pexpect 4.9.0 pillow 10.2.0 pip 24.3.1 platformdirs 4.3.6 prometheus_client 0.21.0 prompt_toolkit 3.0.48 propcache 0.2.1 protobuf 5.29.1 psutil 6.1.0 ptyprocess 0.7.0 pure_eval 0.2.3 py-cpuinfo 9.0.0 pyarrow 18.1.0 pycparser 2.22 pydantic 2.10.3 pydantic_core 2.27.1 pydub 0.25.1 Pygments 2.18.0 PyGObject 3.42.1 PyJWT 2.3.0 pyparsing 2.4.7 python-apt 2.4.0+ubuntu4 python-dateutil 2.9.0.post0 python-json-logger 2.0.7 python-multipart 0.0.20 pytz 2024.2 PyYAML 6.0.2 pyzmq 24.0.1 referencing 0.35.1 regex 2024.11.6 requests 2.32.3 rfc3339-validator 0.1.4 rfc3986-validator 0.1.1 rich 13.9.4 rpds-py 0.21.0 ruff 0.8.3 safehttpx 0.1.6 safetensors 0.4.5 scikit-learn 1.6.0 scipy 1.14.1 SecretStorage 3.3.1 semantic-version 2.10.0 Send2Trash 1.8.3 sentry-sdk 2.19.2 setproctitle 1.3.4 setuptools 75.4.0 shellingham 1.5.4 six 1.16.0 smmap 5.0.1 sniffio 1.3.1 soupsieve 2.6 stack-data 0.6.3 starlette 0.41.3 svgwrite 1.4.3 sympy 1.13.1 termcolor 2.5.0 terminado 0.18.1 threadpoolctl 3.5.0 tinycss2 1.4.0 tokenizers 0.21.0 tomlkit 0.13.2 torch 2.5.1+cu124 torchaudio 2.5.1+cu124 torchvision 0.20.1+cu124 tornado 6.4.1 tqdm 4.67.1 traitlets 5.14.3 transformers 4.47.1 triton 3.1.0 typer 0.15.1 types-python-dateutil 2.9.0.20241003 typing_extensions 4.12.2 tzdata 2024.2 uri-template 1.3.0 urllib3 2.2.3 uvicorn 0.34.0 wadllib 1.3.6 wandb 0.19.1 wcwidth 0.2.13 webcolors 24.11.1 webencodings 0.5.1 websocket-client 1.8.0 websockets 14.1 wheel 0.45.0 widgetsnbextension 
4.0.13 xxhash 3.5.0 yarl 1.18.3 zipp 1.0.0
nvcc --version
nvcc: NVIDIA (R) Cuda compiler driver Copyright (c) 2005-2024 NVIDIA Corporation Built on Thu_Sep_12_02:18:05_PDT_2024 Cuda compilation tools, release 12.6, V12.6.77 Build cuda_12.6.r12.6/compiler.34841621_0
nvidia-smi
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.05             Driver Version: 550.127.05     CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA A40                    On   |   00000000:52:00.0 Off |                    0 |
|  0%   72C    P0           266W /  300W  |  24277MiB /  46068MiB  |     99%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA A40                    On   |   00000000:57:00.0 Off |                    0 |
|  0%   73C    P0           253W /  300W  |  41029MiB /  46068MiB  |     96%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+
Any guidance or suggestions would be greatly appreciated. Thanks in advance!