Hello,
I have trained a model using FSDP. More specifically, I used this run_clm.py script with the option `--fsdp "shard_grad_op auto_wrap"`.
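(For context, as far as I understand this flag, it corresponds to the following `TrainingArguments` setup; everything apart from `fsdp` and `output_dir` below is a placeholder, not my real configuration:)

```python
from transformers import TrainingArguments

# Rough equivalent of passing --fsdp "shard_grad_op auto_wrap" to run_clm.py.
# SHARD_GRAD_OP shards gradients and optimizer state; auto_wrap wraps
# submodules automatically. The batch size is a placeholder.
args = TrainingArguments(
    output_dir="/data/users/tarun/coref/models/output32",
    fsdp="shard_grad_op auto_wrap",
    per_device_train_batch_size=1,
)
```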
The training went fine and the model was saved. However, when I try to load the model, I get the following error:

```
Loading checkpoint shards: 100%|███████████████████████████████████████| 3/3 [00:00<00:00, 19.55it/s]
Traceback (most recent call last):
  File "/home/tarun/memory-llm-paper/run_experiments_pythia.py", line 49, in <module>
    model = AutoModelForCausalLM.from_pretrained('/data/users/tarun/coref/models/output32/checkpoint-345')
  File "/home/tarun/miniconda3/envs/coref-1/lib/python3.9/site-packages/transformers/models/auto/auto_factory.py", line 564, in from_pretrained
    return model_class.from_pretrained(
  File "/home/tarun/miniconda3/envs/coref-1/lib/python3.9/site-packages/transformers/modeling_utils.py", line 4014, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/home/tarun/miniconda3/envs/coref-1/lib/python3.9/site-packages/transformers/modeling_utils.py", line 4559, in _load_pretrained_model
    raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
RuntimeError: Error(s) in loading state_dict for GPTNeoXForCausalLM:
	size mismatch for gpt_neox.embed_in.weight: copying a param with shape torch.Size([128778240]) from checkpoint, the shape in current model is torch.Size([50304, 2560]).
	size mismatch for gpt_neox.final_layer_norm.bias: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([2560]).
	size mismatch for embed_out.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([50304, 2560]).
	You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.
```
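To confirm what actually ended up on disk, one can list the tensor shapes stored in the checkpoint shards, e.g. as below (a quick sketch; the shard filename is a guess based on the 3/3 in the loading bar above, and it assumes safetensors shards — for .bin shards, `torch.load` works similarly):

```python
from safetensors import safe_open

# Hypothetical shard name -- use whichever model-*.safetensors files actually
# sit in the checkpoint directory.
shard = "/data/users/tarun/coref/models/output32/checkpoint-345/model-00001-of-00003.safetensors"

with safe_open(shard, framework="pt") as f:
    for name in f.keys():
        # get_slice reads the shape from the header without loading the tensor
        print(name, f.get_slice(name).get_shape())
```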
Now, if we look specifically at this line:

```
size mismatch for gpt_neox.embed_in.weight: copying a param with shape torch.Size([128778240]) from checkpoint, the shape in current model is torch.Size([50304, 2560]).
```

we can see that 50304 * 2560 = 128778240. So it seems that at the end of FSDP training the model's parameters were saved as FSDP's flattened 1-D buffers (the zero-sized tensors presumably being shards that lived on other ranks), and `from_pretrained` is not able to unflatten them automatically.
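If that diagnosis is right, the fix presumably has to happen at save time: the full, unflattened parameters need to be gathered onto rank 0 before the state dict is written, along these lines with PyTorch's FSDP state-dict API (an untested sketch, with `model` being the FSDP-wrapped module and the output directory made up):

```python
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    StateDictType,
)

# Gather full, unflattened parameters onto rank 0 (offloaded to CPU) instead
# of letting each rank dump its flat local shard.
cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
    full_state = model.state_dict()

if dist.get_rank() == 0:
    # save_pretrained accepts an explicit state_dict, so the underlying HF
    # model can be written in the normal, loadable format.
    model.module.save_pretrained(
        "/data/users/tarun/coref/models/output32-full",  # hypothetical path
        state_dict=full_state,
    )
```

That said, I was using the stock run_clm.py / Trainer saving path, so I would have expected this gathering to be handled for me, which is part of why I'm confused.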
In case it helps, here are my conda environment details:

```yaml
name: coref-1
channels:
- pytorch
- plotly
- nvidia
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=conda_forge
- _openmp_mutex=4.5=2_gnu
- _sysroot_linux-64_curr_repodata_hack=3=haa98f57_10
- asttokens=2.4.1=pyhd8ed1ab_0
- attrs=23.1.0=py39h06a4308_0
- binutils_impl_linux-64=2.38=h2a08ee3_1
- binutils_linux-64=2.38.0=hc2dff05_0
- blas=1.0=openblas
- bzip2=1.0.8=h4bc722e_7
- c-ares=1.33.1=heb4867d_0
- ca-certificates=2024.7.4=hbcca054_0
- cmake=3.30.2=hf8c4bd3_0
- comm=0.2.2=pyhd8ed1ab_0
- cuda-cudart=12.4.127=0
- cuda-cupti=12.4.127=0
- cuda-libraries=12.4.0=0
- cuda-nvrtc=12.4.127=0
- cuda-nvtx=12.4.127=0
- cuda-opencl=12.4.127=0
- cuda-runtime=12.4.0=0
- cuda-version=11.8=hcce14f8_3
- cudatoolkit=11.8.0=h6a678d5_0
- cudnn=8.9.2.26=cuda11_0
- cupti=11.8.0=he078b1a_0
- debugpy=1.8.5=py39h98e3656_0
- decorator=5.1.1=pyhd8ed1ab_0
- exceptiongroup=1.2.2=pyhd8ed1ab_0
- executing=2.0.1=pyhd8ed1ab_0
- filelock=3.13.1=py39h06a4308_0
- fsspec=2024.6.1=py39h06a4308_0
- gcc_impl_linux-64=11.2.0=h1234567_1
- gcc_linux-64=11.2.0=h5c386dc_0
- gh=2.55.0=h76a2195_0
- gmp=6.2.1=h295c915_3
- gmpy2=2.1.2=py39heeb90bb_0
- gxx_impl_linux-64=11.2.0=h1234567_1
- gxx_linux-64=11.2.0=hc2dff05_0
- importlib-metadata=8.4.0=pyha770c72_0
- importlib_metadata=8.4.0=hd8ed1ab_0
- intel-openmp=2022.0.1=h06a4308_3633
- ipykernel=6.29.5=pyh3099207_0
- ipython=8.18.1=pyh707e725_3
- ipywidgets=8.1.5=pyhd8ed1ab_0
- jedi=0.19.1=pyhd8ed1ab_0
- jinja2=3.1.4=py39h06a4308_0
- jsonschema-specifications=2023.7.1=py39h06a4308_0
- jupyter_client=8.6.2=pyhd8ed1ab_0
- jupyter_core=5.7.2=py39hf3d152e_0
- jupyterlab_widgets=3.0.13=pyhd8ed1ab_0
- kernel-headers_linux-64=3.10.0=h57e8cba_10
- keyutils=1.6.1=h166bdaf_0
- krb5=1.21.3=h659f571_0
- ld_impl_linux-64=2.38=h1181459_1
- libabseil=20240116.2=cxx17_h6a678d5_0
- libblas=3.9.0=16_linux64_openblas
- libcublas=12.4.2.65=0
- libcufft=11.2.0.44=0
- libcufile=1.9.1.3=0
- libcurand=10.3.5.147=0
- libcurl=8.9.1=hdb1bdb2_0
- libcusolver=11.6.0.99=0
- libcusparse=12.3.0.142=0
- libedit=3.1.20191231=he28a2e2_2
- libev=4.33=hd590300_2
- libexpat=2.6.2=h59595ed_0
- libffi=3.4.4=h6a678d5_1
- libgcc-devel_linux-64=11.2.0=h1234567_1
- libgcc-ng=14.1.0=h77fa898_0
- libgfortran-ng=11.2.0=h00389a5_1
- libgfortran5=11.2.0=h1234567_1
- libgomp=14.1.0=h77fa898_0
- liblapack=3.9.0=16_linux64_openblas
- libmagma=2.8.0=hfdb99dd_0
- libmagma_sparse=2.8.0=h9ddd185_0
- libnghttp2=1.58.0=h47da74e_1
- libnpp=12.2.5.2=0
- libnsl=2.0.1=hd590300_0
- libnvfatbin=12.4.127=0
- libnvjitlink=12.4.99=0
- libnvjpeg=12.3.1.89=0
- libopenblas=0.3.21=h043d6bf_0
- libprotobuf=4.25.3=h08a7969_0
- libsodium=1.0.18=h36c2ea0_1
- libsqlite=3.46.0=hde9e2c9_0
- libssh2=1.11.0=h0841786_0
- libstdcxx-devel_linux-64=11.2.0=h1234567_1
- libstdcxx-ng=14.1.0=hc0a3c3a_0
- libuuid=2.38.1=h0b41bf4_0
- libuv=1.48.0=hd590300_0
- libxcrypt=4.4.36=hd590300_1
- libzlib=1.3.1=h4ab18f5_1
- llvm-openmp=15.0.7=h0cdce71_0
- magma=2.8.0=h4aca40b_0
- markupsafe=2.1.3=py39h5eee18b_0
- matplotlib-inline=0.1.7=pyhd8ed1ab_0
- mkl=2022.1.0=hc2b9512_224
- mpc=1.1.0=h10f8cd9_1
- mpfr=4.0.2=hb69a4c5_1
- mpmath=1.3.0=py39h06a4308_0
- nbformat=5.9.2=py39h06a4308_0
- ncurses=6.5=he02047a_1
- nest-asyncio=1.6.0=pyhd8ed1ab_0
- networkx=3.2.1=py39h06a4308_0
- numpy=1.26.4=py39heeff2f4_0
- numpy-base=1.26.4=py39h8a23956_0
- openssl=3.3.1=hb9d3cd8_3
- packaging=24.1=pyhd8ed1ab_0
- parso=0.8.4=pyhd8ed1ab_0
- pexpect=4.9.0=pyhd8ed1ab_0
- pickleshare=0.7.5=py_1003
- pip=24.2=py39h06a4308_0
- platformdirs=4.2.2=pyhd8ed1ab_0
- plotly=5.23.0=py_0
- prompt-toolkit=3.0.47=pyha770c72_0
- psutil=6.0.0=py39hd3abc70_0
- ptyprocess=0.7.0=pyhd3deb0d_0
- pure_eval=0.2.3=pyhd8ed1ab_0
- pybind11-abi=4=hd3eb1b0_1
- pygments=2.18.0=pyhd8ed1ab_0
- python=3.9.19=h0755675_0_cpython
- python-dateutil=2.9.0=pyhd8ed1ab_0
- python-fastjsonschema=2.16.2=py39h06a4308_0
- python_abi=3.9=5_cp39
- pytorch=2.3.0=gpu_cuda118py39h796af20_101
- pytorch-cuda=12.4=hc786d27_6
- pytorch-mutex=1.0=cuda
- pyyaml=6.0.1=py39h5eee18b_0
- pyzmq=26.2.0=py39h4e4fb57_0
- readline=8.2=h5eee18b_0
- rhash=1.4.4=hd590300_0
- rpds-py=0.10.6=py39hb02cf49_0
- scipy=1.13.1=py39heeff2f4_0
- setuptools=72.1.0=py39h06a4308_0
- six=1.16.0=pyh6c4a22f_0
- sqlite=3.46.0=h6d4b2fc_0
- stack_data=0.6.2=pyhd8ed1ab_0
- sympy=1.12=py39h06a4308_0
- sysroot_linux-64=2.17=h57e8cba_10
- tenacity=8.2.3=py39h06a4308_0
- tk=8.6.13=noxft_h4845f30_101
- torchtriton=2.3.0=cuda123py39hdb19cb5_0
- tornado=6.4.1=py39hd3abc70_0
- traitlets=5.14.3=pyhd8ed1ab_0
- typing_extensions=4.11.0=py39h06a4308_0
- wcwidth=0.2.13=pyhd8ed1ab_0
- wheel=0.43.0=py39h06a4308_0
- widgetsnbextension=4.0.13=pyhd8ed1ab_0
- xz=5.4.6=h5eee18b_1
- yaml=0.2.5=h7b6447c_0
- zeromq=4.3.5=h75354e8_4
- zipp=3.20.1=pyhd8ed1ab_0
- zlib=1.3.1=h4ab18f5_1
- zstd=1.5.6=ha6fb4c9_0
- pip:
  - accelerate==0.34.2
  - aiohappyeyeballs==2.4.2
  - aiohttp==3.10.8
  - aiosignal==1.3.1
  - async-timeout==4.0.3
  - certifi==2024.8.30
  - charset-normalizer==3.3.2
  - contourpy==1.3.0
  - cycler==0.12.1
  - datasets==3.0.1
  - dill==0.3.8
  - evaluate==0.4.3
  - fonttools==4.54.1
  - frozenlist==1.4.1
  - huggingface-hub==0.25.1
  - idna==3.10
  - importlib-resources==6.4.5
  - joblib==1.4.2
  - kiwisolver==1.4.7
  - matplotlib==3.9.2
  - multidict==6.1.0
  - multiprocess==0.70.16
  - pandas==2.2.3
  - peft==0.13.0
  - pillow==10.4.0
  - protobuf==5.28.2
  - pyarrow==17.0.0
  - pyparsing==3.1.4
  - pytz==2024.2
  - regex==2024.9.11
  - requests==2.32.3
  - safetensors==0.4.5
  - scikit-learn==1.5.2
  - sentencepiece==0.2.0
  - threadpoolctl==3.5.0
  - tokenizers==0.20.0
  - tqdm==4.66.5
  - transformers==4.46.0.dev0
  - tzdata==2024.2
  - urllib3==2.2.3
  - xxhash==3.5.0
  - yarl==1.13.1
prefix: /home/tarun/miniconda3/envs/coref-1
```
Does anyone have any idea what could be causing this? I didn't find much by searching online.