Getting an error when fine-tuning Llama 2 via QLoRA with FSDP

Hi All,

I am trying to fine-tune Llama 2 on a custom dataset. I'm using the QLoRA technique (following the example peft/examples/causal_language_modeling/peft_lora_clm_accelerate_ds_zero3_offload.py) together with FSDP (following the FSDP documentation for Accelerate) in the Accelerate library. I'm also using Flash Attention, following the example philschmid/deep-learning-pytorch-huggingface/training/instruction-tune-llama-2-int4.ipynb.

I’m getting the error “ValueError: Integer parameters are unsupported” when executing the accelerator.prepare line.

# Imports used in this snippet (self.config_dict, self.hf_token, self.param_dict,
# and the flash-attention flags are set elsewhere in the class)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

# BitsAndBytes 4-bit (QLoRA) quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Quantized model loading
model = AutoModelForCausalLM.from_pretrained(
    self.config_dict['model_name'],
    token=self.hf_token,
    quantization_config=bnb_config,
    use_cache=False,
    # trust_remote_code=True,
    # device_map="auto",  # raises an error in the multi-GPU setup, so commented out
)
model.config.pretraining_tp = 1
model.gradient_checkpointing_enable()
accelerator.print('Quantized Model Loaded')

# Check that the Flash Attention patch was applied
if self.flash_attension and self.use_flash_attention:
    from src.utils.llama_patch import forward
    assert model.model.layers[0].self_attn.forward.__doc__ == forward.__doc__, \
        "Model is not using flash attention"

# Tokenizer loading
tokenizer = AutoTokenizer.from_pretrained(
    self.config_dict['model_name'],
    use_fast=True,
    token=self.hf_token,
    padding_side='right',
    truncation_side='left',
)
tokenizer.pad_token = tokenizer.eos_token

# LoRA config based on the QLoRA paper
peft_config = LoraConfig(
    lora_alpha=self.param_dict['LoraConfig']['lora_alpha'],      # 16
    lora_dropout=self.param_dict['LoraConfig']['lora_dropout'],  # 0.1
    r=self.param_dict['LoraConfig']['r'],                        # 64
    target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj', 'down_proj'],  # ['embed_tokens','q_proj','k_proj','v_proj','o_proj','down_proj']
    bias=self.param_dict['LoraConfig']['bias'],                  # "none"
    task_type=TaskType.CAUSAL_LM,
)

# Prepare the quantized model for k-bit training
model = prepare_model_for_kbit_training(model)
if self.flash_attension and self.use_flash_attention:
    from src.utils.llama_patch import upcast_layer_for_flash_attention
    model = upcast_layer_for_flash_attention(model, torch.bfloat16)

# Attach the LoRA adapters
model = get_peft_model(model, peft_config)
...
# FSDP Accelerate model wrapper (this is the line that raises the ValueError)
model = accelerator.prepare(model)
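
For context, accelerator in the snippet above is created earlier in the script. A minimal sketch of that setup, assuming the Accelerator is constructed with defaults and all FSDP options come from the accelerate config shown below (not from code):

from accelerate import Accelerator

# Assumed setup: no plugin objects are passed in code; the FSDP settings
# are picked up from the accelerate config used when launching the script.
accelerator = Accelerator()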

Error Screenshot

Accelerate Config:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_forward_prefetch: false
  fsdp_offload_params: true
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sync_module_states: false
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: 'no'
num_machines: 1
num_processes: 2
rdzv_backend: static
same_network: true
tpu_env: []
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false

Conda.yaml:

name: my_VM
channels:
  - nvidia
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=5.1=1_gnu
  - ca-certificates=2023.08.22=h06a4308_0
  - cuda-cccl=12.2.140=0
  - cuda-command-line-tools=12.2.2=0
  - cuda-compiler=12.2.2=0
  - cuda-cudart=12.2.140=0
  - cuda-cudart-dev=12.2.140=0
  - cuda-cudart-static=12.2.140=0
  - cuda-cuobjdump=12.2.140=0
  - cuda-cupti=12.2.142=0
  - cuda-cupti-static=12.2.142=0
  - cuda-cuxxfilt=12.2.140=0
  - cuda-documentation=12.2.140=0
  - cuda-driver-dev=12.2.140=0
  - cuda-gdb=12.2.140=0
  - cuda-libraries=12.2.2=0
  - cuda-libraries-dev=12.2.2=0
  - cuda-libraries-static=12.2.2=0
  - cuda-nsight=12.2.144=0
  - cuda-nsight-compute=12.2.2=0
  - cuda-nvcc=12.2.140=0
  - cuda-nvdisasm=12.2.140=0
  - cuda-nvml-dev=12.2.140=0
  - cuda-nvprof=12.2.142=0
  - cuda-nvprune=12.2.140=0
  - cuda-nvrtc=12.2.140=0
  - cuda-nvrtc-dev=12.2.140=0
  - cuda-nvrtc-static=12.2.140=0
  - cuda-nvtx=12.2.140=0
  - cuda-nvvp=12.2.142=0
  - cuda-opencl=12.2.140=0
  - cuda-opencl-dev=12.2.140=0
  - cuda-profiler-api=12.2.140=0
  - cuda-sanitizer-api=12.2.140=0
  - cuda-toolkit=12.2.2=0
  - cuda-tools=12.2.2=0
  - cuda-visual-tools=12.2.2=0
  - gds-tools=1.7.2.10=0
  - ld_impl_linux-64=2.38=h1181459_1
  - libcublas=12.2.5.6=0
  - libcublas-dev=12.2.5.6=0
  - libcublas-static=12.2.5.6=0
  - libcufft=11.0.8.103=0
  - libcufft-dev=11.0.8.103=0
  - libcufft-static=11.0.8.103=0
  - libcufile=1.7.2.10=0
  - libcufile-dev=1.7.2.10=0
  - libcufile-static=1.7.2.10=0
  - libcurand=10.3.3.141=0
  - libcurand-dev=10.3.3.141=0
  - libcurand-static=10.3.3.141=0
  - libcusolver=11.5.2.141=0
  - libcusolver-dev=11.5.2.141=0
  - libcusolver-static=11.5.2.141=0
  - libcusparse=12.1.2.141=0
  - libcusparse-dev=12.1.2.141=0
  - libcusparse-static=12.1.2.141=0
  - libffi=3.4.4=h6a678d5_0
  - libgcc-ng=11.2.0=h1234567_1
  - libgomp=11.2.0=h1234567_1
  - libnpp=12.2.1.4=0
  - libnpp-dev=12.2.1.4=0
  - libnpp-static=12.2.1.4=0
  - libnvjitlink=12.2.140=0
  - libnvjitlink-dev=12.2.140=0
  - libnvjpeg=12.2.2.4=0
  - libnvjpeg-dev=12.2.2.4=0
  - libnvjpeg-static=12.2.2.4=0
  - libstdcxx-ng=11.2.0=h1234567_1
  - ncurses=6.4=h6a678d5_0
  - nsight-compute=2023.2.2.3=0
  - openssl=3.0.11=h7f8727e_2
  - pip=23.2.1=py39h06a4308_0
  - python=3.9.18=h955ad1f_0
  - readline=8.2=h5eee18b_0
  - setuptools=68.0.0=py39h06a4308_0
  - sqlite=3.41.2=h5eee18b_0
  - tk=8.6.12=h1ccaba5_0
  - wheel=0.41.2=py39h06a4308_0
  - xz=5.4.2=h5eee18b_0
  - zlib=1.2.13=h5eee18b_0
  - pip:
      - accelerate==0.24.0.dev0
      - aiohttp==3.8.5
      - aiosignal==1.3.1
      - appdirs==1.4.4
      - async-timeout==4.0.3
      - attrs==23.1.0
      - azure-core==1.29.4
      - azure-identity==1.14.0
      - azure-storage-blob==12.18.2
      - bitsandbytes==0.41.1
      - certifi==2022.12.7
      - cffi==1.16.0
      - charset-normalizer==2.1.1
      - click==8.1.7
      - cmake==3.25.0
      - contourpy==1.1.1
      - cryptography==41.0.4
      - cycler==0.12.0
      - datasets==2.14.5
      - dill==0.3.7
      - docker-pycreds==0.4.0
      - docstring-parser==0.15
      - einops==0.7.0
      - evaluate==0.4.0
      - filelock==3.9.0
      - flash-attn==2.3.0
      - fonttools==4.43.0
      - frozenlist==1.4.0
      - fsspec==2023.6.0
      - gitdb==4.0.10
      - gitpython==3.1.37
      - huggingface-hub==0.17.3
      - idna==3.4
      - importlib-resources==6.1.0
      - isodate==0.6.1
      - jinja2==3.1.2
      - kiwisolver==1.4.5
      - lit==15.0.7
      - markdown-it-py==3.0.0
      - markupsafe==2.1.2
      - matplotlib==3.8.0
      - mdurl==0.1.2
      - mpmath==1.3.0
      - msal==1.24.1
      - msal-extensions==1.0.0
      - multidict==6.0.4
      - multiprocess==0.70.15
      - networkx==3.0
      - ninja==1.11.1
      - numpy==1.24.1
      - packaging==23.2
      - pandas==2.1.1
      - pathtools==0.1.2
      - peft==0.6.0.dev0
      - pillow==9.3.0
      - portalocker==2.8.2
      - protobuf==4.24.3
      - psutil==5.9.5
      - pyarrow==13.0.0
      - pycparser==2.21
      - pygments==2.16.1
      - pyjwt==2.8.0
      - pyparsing==3.1.1
      - python-dateutil==2.8.2
      - python-dotenv==1.0.0
      - pytz==2023.3.post1
      - pyyaml==6.0.1
      - regex==2023.8.8
      - requests==2.28.1
      - responses==0.18.0
      - rich==13.6.0
      - safetensors==0.3.3
      - scipy==1.11.3
      - sentry-sdk==1.31.0
      - setproctitle==1.3.2
      - shtab==1.6.4
      - six==1.16.0
      - smmap==5.0.1
      - sympy==1.12
      - tokenizers==0.13.3
      - torch==2.0.1+cu118
      - torchaudio==2.0.2+cu118
      - torchvision==0.15.2+cu118
      - tqdm==4.66.1
      - transformers==4.33.3
      - triton==2.0.0
      - trl==0.7.2.dev0
      - typing-extensions==4.8.0
      - tyro==0.5.9
      - tzdata==2023.3
      - urllib3==1.26.13
      - wandb==0.15.11
      - xxhash==3.3.0
      - yarl==1.9.2
      - zipp==3.17.0
prefix: /home/azureuser/miniconda3/envs/my_VM

The idea is to fine-tune LLM models such as Llama 2 (7B, 13B, and up to 70B) on a custom dataset for a specific use case.

Looking forward to your reply.

Regards
Nabarun Barua
