Llama 2 7B fine-tuned with IA3 errors when performing inference

I have fine-tuned meta-llama/Llama-2-7b-hf using both LoRA and IA3 with bitsandbytes quantization. When I run inference on the PEFT model with the following script:

import torch
from peft import PeftModel, PeftConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)

def run_model():
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )

    config = PeftConfig.from_pretrained("path/to/peft/model")

    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        torch_dtype=torch.bfloat16,
        quantization_config=quantization_config,
        device_map="auto",
    )

    model = PeftModel.from_pretrained(model, "path/to/peft/model")

    print(model)

    prompt = f"""
    ### Input:
    What is the capital of Spain?

    ### Response:
    """

    input_ids = tokenizer(
        prompt, return_tensors="pt", truncation=True
    ).input_ids.cuda()

    outputs = model.generate(
        input_ids=input_ids,
        max_new_tokens=500,
        do_sample=True,
        top_p=0.9,
        temperature=0.9,
    )

    print(
        f"Response: \n{tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0][len(prompt):]}"
    )


if __name__ == "__main__":
    run_model()

I get RuntimeError: mat1 and mat2 shapes cannot be multiplied (24x4096 and 1x8388608)

This only happens when I run the IA3 PEFT model. LoRA works as expected.

Full stack trace:
Traceback (most recent call last):
  File "/zhome/03/c/164482/code/herd/herd/run_ia3.py", line 72, in <module>
    run_model(
  File "/zhome/03/c/164482/code/herd/herd/run_ia3.py", line 58, in run_model
    outputs = model.generate(
              ^^^^^^^^^^^^^^^
  File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/peft/peft_model.py", line 975, in generate
    outputs = self.base_model.generate(**kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/transformers/generation/utils.py", line 1588, in generate
    return self.sample(
           ^^^^^^^^^^^^
  File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/transformers/generation/utils.py", line 2642, in sample
    outputs = self(
              ^^^^^
  File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 806, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 693, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 408, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 306, in forward
    key_states = self.k_proj(hidden_states)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/peft/tuners/ia3.py", line 449, in forward
    result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mat1 and mat2 shapes cannot be multiplied (24x4096 and 1x8388608)

I’m using:

transformers 4.31.0
peft 0.5.0
bitsandbytes 0.41.1
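
For what it's worth, the second operand's shape looks like bitsandbytes' packed 4-bit buffer rather than a dense weight: 8388608 = 4096 × 4096 / 2, i.e. a uint8 tensor holding two 4-bit values per byte, flattened to shape (8388608, 1). That suggests the IA3 wrapper around k_proj is calling F.linear directly on the raw packed weight instead of going through a 4-bit dequantizing path. A quick way to check (a sketch reusing the model object from the script above; the attribute path is the one the print(model) output shows for Llama):

# Inspect the wrapped k_proj of the first decoder layer (path taken from print(model)).
# If the weight is uint8 with shape (8388608, 1), it is the packed nf4 buffer and the
# IA3 layer is treating it as a dense floating-point weight.
k_proj = model.base_model.model.model.layers[0].self_attn.k_proj
print(type(k_proj))
print(k_proj.weight.dtype, tuple(k_proj.weight.shape))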

Apparently the error was caused by loading the base model with a quantization config that did not match the one used for fine-tuning. Loading the adapter directly with AutoPeftModelForCausalLM fixed it:

import torch
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained(
    save_path,
    torch_dtype=torch.bfloat16,
    load_in_4bit=True,
)
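
If you want to be explicit that the base model is reloaded with exactly the same quantization settings the adapter was trained against, the full BitsAndBytesConfig should also work here, since AutoPeftModelForCausalLM forwards extra keyword arguments to the base model's from_pretrained. A sketch (the nf4 / double-quant values below are assumptions; use whatever your fine-tuning run actually used):

import torch
from peft import AutoPeftModelForCausalLM
from transformers import BitsAndBytesConfig

# Must match the quantization settings used during fine-tuning (assumed here: nf4 + double quantization).
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

model = AutoPeftModelForCausalLM.from_pretrained(
    save_path,  # path to the saved IA3 adapter
    torch_dtype=torch.bfloat16,
    quantization_config=quantization_config,
    device_map="auto",
)

The important part is simply that these settings match the ones the adapter was trained with.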

I use this method to load a fine-tuned Llama-2-7b for inference, but I run into the following error with IA3:

Stack trace:

  File "/public/zenghui/FedLLMTuning/NLG/train_e2e.py", line 242, in eval_generate
    batch_outputs = model.generate(
  File "/usr/local/miniconda3/envs/huggingface/lib/python3.10/site-packages/peft/peft_model.py", line 1128, in generate
    outputs = self.base_model.generate(**kwargs)
  File "/usr/local/miniconda3/envs/huggingface/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/miniconda3/envs/huggingface/lib/python3.10/site-packages/transformers/generation/utils.py", line 1789, in generate
    return self.beam_sample(
  File "/usr/local/miniconda3/envs/huggingface/lib/python3.10/site-packages/transformers/generation/utils.py", line 3417, in beam_sample
    outputs = self(
  File "/usr/local/miniconda3/envs/huggingface/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/miniconda3/envs/huggingface/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/usr/local/miniconda3/envs/huggingface/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1052, in forward
    logits = self.lm_head(hidden_states)
  File "/usr/local/miniconda3/envs/huggingface/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/miniconda3/envs/huggingface/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/usr/local/miniconda3/envs/huggingface/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: expected scalar type Float but found BFloat16

Changing the code as below solves the problem:

model = AutoPeftModelForCausalLM.from_pretrained(
    save_path,
    torch_dtype=torch.float,
    load_in_4bit=True,
)

but running inference in float32 is too heavy. Have you run into this problem, and were you able to solve it?
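
One lighter option I have been considering (not verified yet): the failure is in lm_head, where the bfloat16 hidden states appear to meet a float32 weight, so it might be enough to keep torch_dtype=torch.bfloat16 and cast just the output head back to bfloat16 after loading. A sketch:

import torch
from peft import AutoPeftModelForCausalLM

model = AutoPeftModelForCausalLM.from_pretrained(
    save_path,
    torch_dtype=torch.bfloat16,
    load_in_4bit=True,
)

# Cast only the (non-quantized) output head so its weight dtype matches the
# bfloat16 hidden states; the 4-bit quantized layers are left untouched.
model.get_output_embeddings().to(torch.bfloat16)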