Unable to load a model trained via FSDP

Hello,

I have trained a model using FSDP. More specifically, I used the run_clm.py script with the option --fsdp "shard_grad_op auto_wrap".

The training went fine and the model was saved. However, when I try to load the model, I get this error:

Loading checkpoint shards: 100%|███████████████████████████████████████| 3/3 [00:00<00:00, 19.55it/s]
Traceback (most recent call last):
  File "/home/tarun/memory-llm-paper/run_experiments_pythia.py", line 49, in <module>
    model = AutoModelForCausalLM.from_pretrained('/data/users/tarun/coref/models/output32/checkpoint-345')
  File "/home/tarun/miniconda3/envs/coref-1/lib/python3.9/site-packages/transformers/models/auto/auto_factory.py", line 564, in from_pretrained
    return model_class.from_pretrained(
  File "/home/tarun/miniconda3/envs/coref-1/lib/python3.9/site-packages/transformers/modeling_utils.py", line 4014, in from_pretrained
    ) = cls._load_pretrained_model(
  File "/home/tarun/miniconda3/envs/coref-1/lib/python3.9/site-packages/transformers/modeling_utils.py", line 4559, in _load_pretrained_model
    raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
RuntimeError: Error(s) in loading state_dict for GPTNeoXForCausalLM:
        size mismatch for gpt_neox.embed_in.weight: copying a param with shape torch.Size([128778240]) from checkpoint, the shape in current model is torch.Size([50304, 2560]).
        size mismatch for gpt_neox.final_layer_norm.bias: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([2560]).
        size mismatch for embed_out.weight: copying a param with shape torch.Size([0]) from checkpoint, the shape in current model is torch.Size([50304, 2560]).
        You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.

Now, if we look specifically at this line:

size mismatch for gpt_neox.embed_in.weight: copying a param with shape torch.Size([128778240]) from checkpoint, the shape in current model is torch.Size([50304, 2560]).

we can see that 50304 * 2560 = 128778240. So it seems that at the end of FSDP training, the model's parameters were stored as a flattened 1-D array, and from_pretrained isn't able to unflatten them automatically.
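
For anyone who wants to verify this on their own checkpoint, something like the snippet below should print the flattened shapes straight from the saved shards. This is only a sketch: I'm assuming the checkpoint was saved as safetensors shards with the standard transformers naming, so the shard filename may differ in your checkpoint directory.

# Sketch: inspect the raw tensor shapes stored in one checkpoint shard.
# The shard filename is an assumption based on the "3/3" in the loading log.
from safetensors import safe_open

shard = "/data/users/tarun/coref/models/output32/checkpoint-345/model-00001-of-00003.safetensors"
with safe_open(shard, framework="pt", device="cpu") as f:
    for name in f.keys():
        print(name, tuple(f.get_tensor(name).shape))
        # e.g. gpt_neox.embed_in.weight (128778240,) instead of (50304, 2560)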

Here are my conda environment details, in case that helps:

name: coref-1
channels:

  • pytorch
  • plotly
  • nvidia
  • conda-forge
  • defaults

dependencies:
  • _libgcc_mutex=0.1=conda_forge
  • _openmp_mutex=4.5=2_gnu
  • _sysroot_linux-64_curr_repodata_hack=3=haa98f57_10
  • asttokens=2.4.1=pyhd8ed1ab_0
  • attrs=23.1.0=py39h06a4308_0
  • binutils_impl_linux-64=2.38=h2a08ee3_1
  • binutils_linux-64=2.38.0=hc2dff05_0
  • blas=1.0=openblas
  • bzip2=1.0.8=h4bc722e_7
  • c-ares=1.33.1=heb4867d_0
  • ca-certificates=2024.7.4=hbcca054_0
  • cmake=3.30.2=hf8c4bd3_0
  • comm=0.2.2=pyhd8ed1ab_0
  • cuda-cudart=12.4.127=0
  • cuda-cupti=12.4.127=0
  • cuda-libraries=12.4.0=0
  • cuda-nvrtc=12.4.127=0
  • cuda-nvtx=12.4.127=0
  • cuda-opencl=12.4.127=0
  • cuda-runtime=12.4.0=0
  • cuda-version=11.8=hcce14f8_3
  • cudatoolkit=11.8.0=h6a678d5_0
  • cudnn=8.9.2.26=cuda11_0
  • cupti=11.8.0=he078b1a_0
  • debugpy=1.8.5=py39h98e3656_0
  • decorator=5.1.1=pyhd8ed1ab_0
  • exceptiongroup=1.2.2=pyhd8ed1ab_0
  • executing=2.0.1=pyhd8ed1ab_0
  • filelock=3.13.1=py39h06a4308_0
  • fsspec=2024.6.1=py39h06a4308_0
  • gcc_impl_linux-64=11.2.0=h1234567_1
  • gcc_linux-64=11.2.0=h5c386dc_0
  • gh=2.55.0=h76a2195_0
  • gmp=6.2.1=h295c915_3
  • gmpy2=2.1.2=py39heeb90bb_0
  • gxx_impl_linux-64=11.2.0=h1234567_1
  • gxx_linux-64=11.2.0=hc2dff05_0
  • importlib-metadata=8.4.0=pyha770c72_0
  • importlib_metadata=8.4.0=hd8ed1ab_0
  • intel-openmp=2022.0.1=h06a4308_3633
  • ipykernel=6.29.5=pyh3099207_0
  • ipython=8.18.1=pyh707e725_3
  • ipywidgets=8.1.5=pyhd8ed1ab_0
  • jedi=0.19.1=pyhd8ed1ab_0
  • jinja2=3.1.4=py39h06a4308_0
  • jsonschema-specifications=2023.7.1=py39h06a4308_0
  • jupyter_client=8.6.2=pyhd8ed1ab_0
  • jupyter_core=5.7.2=py39hf3d152e_0
  • jupyterlab_widgets=3.0.13=pyhd8ed1ab_0
  • kernel-headers_linux-64=3.10.0=h57e8cba_10
  • keyutils=1.6.1=h166bdaf_0
  • krb5=1.21.3=h659f571_0
  • ld_impl_linux-64=2.38=h1181459_1
  • libabseil=20240116.2=cxx17_h6a678d5_0
  • libblas=3.9.0=16_linux64_openblas
  • libcublas=12.4.2.65=0
  • libcufft=11.2.0.44=0
  • libcufile=1.9.1.3=0
  • libcurand=10.3.5.147=0
  • libcurl=8.9.1=hdb1bdb2_0
  • libcusolver=11.6.0.99=0
  • libcusparse=12.3.0.142=0
  • libedit=3.1.20191231=he28a2e2_2
  • libev=4.33=hd590300_2
  • libexpat=2.6.2=h59595ed_0
  • libffi=3.4.4=h6a678d5_1
  • libgcc-devel_linux-64=11.2.0=h1234567_1
  • libgcc-ng=14.1.0=h77fa898_0
  • libgfortran-ng=11.2.0=h00389a5_1
  • libgfortran5=11.2.0=h1234567_1
  • libgomp=14.1.0=h77fa898_0
  • liblapack=3.9.0=16_linux64_openblas
  • libmagma=2.8.0=hfdb99dd_0
  • libmagma_sparse=2.8.0=h9ddd185_0
  • libnghttp2=1.58.0=h47da74e_1
  • libnpp=12.2.5.2=0
  • libnsl=2.0.1=hd590300_0
  • libnvfatbin=12.4.127=0
  • libnvjitlink=12.4.99=0
  • libnvjpeg=12.3.1.89=0
  • libopenblas=0.3.21=h043d6bf_0
  • libprotobuf=4.25.3=h08a7969_0
  • libsodium=1.0.18=h36c2ea0_1
  • libsqlite=3.46.0=hde9e2c9_0
  • libssh2=1.11.0=h0841786_0
  • libstdcxx-devel_linux-64=11.2.0=h1234567_1
  • libstdcxx-ng=14.1.0=hc0a3c3a_0
  • libuuid=2.38.1=h0b41bf4_0
  • libuv=1.48.0=hd590300_0
  • libxcrypt=4.4.36=hd590300_1
  • libzlib=1.3.1=h4ab18f5_1
  • llvm-openmp=15.0.7=h0cdce71_0
  • magma=2.8.0=h4aca40b_0
  • markupsafe=2.1.3=py39h5eee18b_0
  • matplotlib-inline=0.1.7=pyhd8ed1ab_0
  • mkl=2022.1.0=hc2b9512_224
  • mpc=1.1.0=h10f8cd9_1
  • mpfr=4.0.2=hb69a4c5_1
  • mpmath=1.3.0=py39h06a4308_0
  • nbformat=5.9.2=py39h06a4308_0
  • ncurses=6.5=he02047a_1
  • nest-asyncio=1.6.0=pyhd8ed1ab_0
  • networkx=3.2.1=py39h06a4308_0
  • numpy=1.26.4=py39heeff2f4_0
  • numpy-base=1.26.4=py39h8a23956_0
  • openssl=3.3.1=hb9d3cd8_3
  • packaging=24.1=pyhd8ed1ab_0
  • parso=0.8.4=pyhd8ed1ab_0
  • pexpect=4.9.0=pyhd8ed1ab_0
  • pickleshare=0.7.5=py_1003
  • pip=24.2=py39h06a4308_0
  • platformdirs=4.2.2=pyhd8ed1ab_0
  • plotly=5.23.0=py_0
  • prompt-toolkit=3.0.47=pyha770c72_0
  • psutil=6.0.0=py39hd3abc70_0
  • ptyprocess=0.7.0=pyhd3deb0d_0
  • pure_eval=0.2.3=pyhd8ed1ab_0
  • pybind11-abi=4=hd3eb1b0_1
  • pygments=2.18.0=pyhd8ed1ab_0
  • python=3.9.19=h0755675_0_cpython
  • python-dateutil=2.9.0=pyhd8ed1ab_0
  • python-fastjsonschema=2.16.2=py39h06a4308_0
  • python_abi=3.9=5_cp39
  • pytorch=2.3.0=gpu_cuda118py39h796af20_101
  • pytorch-cuda=12.4=hc786d27_6
  • pytorch-mutex=1.0=cuda
  • pyyaml=6.0.1=py39h5eee18b_0
  • pyzmq=26.2.0=py39h4e4fb57_0
  • readline=8.2=h5eee18b_0
  • rhash=1.4.4=hd590300_0
  • rpds-py=0.10.6=py39hb02cf49_0
  • scipy=1.13.1=py39heeff2f4_0
  • setuptools=72.1.0=py39h06a4308_0
  • six=1.16.0=pyh6c4a22f_0
  • sqlite=3.46.0=h6d4b2fc_0
  • stack_data=0.6.2=pyhd8ed1ab_0
  • sympy=1.12=py39h06a4308_0
  • sysroot_linux-64=2.17=h57e8cba_10
  • tenacity=8.2.3=py39h06a4308_0
  • tk=8.6.13=noxft_h4845f30_101
  • torchtriton=2.3.0=cuda123py39hdb19cb5_0
  • tornado=6.4.1=py39hd3abc70_0
  • traitlets=5.14.3=pyhd8ed1ab_0
  • typing_extensions=4.11.0=py39h06a4308_0
  • wcwidth=0.2.13=pyhd8ed1ab_0
  • wheel=0.43.0=py39h06a4308_0
  • widgetsnbextension=4.0.13=pyhd8ed1ab_0
  • xz=5.4.6=h5eee18b_1
  • yaml=0.2.5=h7b6447c_0
  • zeromq=4.3.5=h75354e8_4
  • zipp=3.20.1=pyhd8ed1ab_0
  • zlib=1.3.1=h4ab18f5_1
  • zstd=1.5.6=ha6fb4c9_0
  • pip:
    • accelerate==0.34.2
    • aiohappyeyeballs==2.4.2
    • aiohttp==3.10.8
    • aiosignal==1.3.1
    • async-timeout==4.0.3
    • certifi==2024.8.30
    • charset-normalizer==3.3.2
    • contourpy==1.3.0
    • cycler==0.12.1
    • datasets==3.0.1
    • dill==0.3.8
    • evaluate==0.4.3
    • fonttools==4.54.1
    • frozenlist==1.4.1
    • huggingface-hub==0.25.1
    • idna==3.10
    • importlib-resources==6.4.5
    • joblib==1.4.2
    • kiwisolver==1.4.7
    • matplotlib==3.9.2
    • multidict==6.1.0
    • multiprocess==0.70.16
    • pandas==2.2.3
    • peft==0.13.0
    • pillow==10.4.0
    • protobuf==5.28.2
    • pyarrow==17.0.0
    • pyparsing==3.1.4
    • pytz==2024.2
    • regex==2024.9.11
    • requests==2.32.3
    • safetensors==0.4.5
    • scikit-learn==1.5.2
    • sentencepiece==0.2.0
    • threadpoolctl==3.5.0
    • tokenizers==0.20.0
    • tqdm==4.66.5
    • transformers==4.46.0.dev0
    • tzdata==2024.2
    • urllib3==2.2.3
    • xxhash==3.5.0
    • yarl==1.13.1

prefix: /home/tarun/miniconda3/envs/coref-1

Does anyone have any idea what could be causing this? I didn't find much by searching the internet.


This is about all I could find.

I couldn’t figure out the underlying reason for the bug.

But downgrading my torch version as shown below helped me get around the error:

conda install pytorch==2.2.2 pytorch-cuda=12.1 -c pytorch -c nvidia

Previously, I had torch 2.3.0 and CUDA 12.4 (as can be seen in the conda environment details in the original question).
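
For completeness, another way around this, which I have not verified on this exact setup, would be to gather a full, unflattened state dict from the FSDP-wrapped model before saving, roughly like this (here model is the FSDP-wrapped GPTNeoXForCausalLM and output_dir is the save directory):

# Sketch (untested here): gather a full state dict onto rank 0 before saving,
# so the checkpoint contains regular 2-D weights that from_pretrained can read.
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType,
    FullStateDictConfig,
)

save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    cpu_state = model.state_dict()  # full, unflattened parameters on rank 0

if dist.get_rank() == 0:
    # model.module is the underlying Hugging Face model inside the FSDP wrapper
    model.module.save_pretrained(output_dir, state_dict=cpu_state)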


It seems to be resolved, but I thought this must be a torch bug, so I did a search, and apparently this behavior is considered normal on the torch forum…
Well, it seems that's just the way it is.