I have fine-tuned meta-llama/Llama-2-7b-hf with both LoRA and IA3, using bitsandbytes quantization. When I run inference on the PEFT model with the following script:
import torch
from peft import PeftModel, PeftConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
)


def run_model():
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )

    config = PeftConfig.from_pretrained("path/to/peft/model")
    model = AutoModelForCausalLM.from_pretrained(
        config.base_model_name_or_path,
        torch_dtype=torch.bfloat16,
        quantization_config=quantization_config,
        device_map="auto",
    )
    model = PeftModel.from_pretrained(model, "path/to/peft/model")
    print(model)

    prompt = f"""
### Input:
What is the capital of Spain?
### Response:
"""

    input_ids = tokenizer(
        prompt, return_tensors="pt", truncation=True
    ).input_ids.cuda()

    outputs = model.generate(
        input_ids=input_ids,
        max_new_tokens=500,
        do_sample=True,
        top_p=0.9,
        temperature=0.9,
    )

    print(
        f"Response: \n{tokenizer.batch_decode(outputs.detach().cpu().numpy(), skip_special_tokens=True)[0][len(prompt):]}"
    )


if __name__ == "__main__":
    run_model()
I get the following error:

RuntimeError: mat1 and mat2 shapes cannot be multiplied (24x4096 and 1x8388608)

This only happens with the IA3 PEFT model; the LoRA model works as expected.
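The second shape lines up exactly with a packed 4-bit weight buffer rather than a real 4096x4096 matrix: k_proj in Llama-2-7b is a 4096x4096 linear layer, and bitsandbytes stores two 4-bit values per byte, so the flattened storage holds 4096 * 4096 / 2 elements. A quick sanity check (hidden_size comes from the Llama-2-7b config; everything else from the error message):

# Sanity check: does the 8388608 in the error match a packed
# 4-bit 4096x4096 weight? bitsandbytes packs two 4-bit values
# per uint8, so the stored buffer has numel / 2 entries.
hidden_size = 4096
packed_elements = hidden_size * hidden_size // 2
print(packed_elements)  # 8388608 -- exactly the second dimension of mat2

And mat1 (24x4096) is just my 24 prompt tokens times the hidden size, so it looks like the IA3 layer is multiplying activations against the raw quantized storage.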
Entire stack trace:
Traceback (most recent call last):
File "/zhome/03/c/164482/code/herd/herd/run_ia3.py", line 72, in <module>
run_model(
File "/zhome/03/c/164482/code/herd/herd/run_ia3.py", line 58, in run_model
outputs = model.generate(
^^^^^^^^^^^^^^^
File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/peft/peft_model.py", line 975, in generate
outputs = self.base_model.generate(**kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/transformers/generation/utils.py", line 1588, in generate
return self.sample(
^^^^^^^^^^^^
File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/transformers/generation/utils.py", line 2642, in sample
outputs = self(
^^^^^
File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 806, in forward
outputs = self.model(
^^^^^^^^^^^
File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 693, in forward
layer_outputs = decoder_layer(
^^^^^^^^^^^^^^
File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 408, in forward
hidden_states, self_attn_weights, present_key_value = self.self_attn(
^^^^^^^^^^^^^^^
File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
output = old_forward(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 306, in forward
key_states = self.k_proj(hidden_states)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/zhome/03/c/164482/code/herd/.venv/lib/python3.11/site-packages/peft/tuners/ia3.py", line 449, in forward
result = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mat1 and mat2 shapes cannot be multiplied (24x4096 and 1x8388608)
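The last frame is telling: peft/tuners/ia3.py line 449 calls F.linear directly on self.weight, which with load_in_4bit would be the packed bitsandbytes Params4bit buffer instead of a dequantized matrix. A probe I intend to run to confirm this (the module path below assumes the usual PeftModel -> IA3Model -> LlamaForCausalLM -> LlamaModel nesting, so treat it as a sketch):

# Hypothetical probe, assuming the standard Llama module layout under PEFT.
layer = model.base_model.model.model.layers[0].self_attn.k_proj
print(type(layer))         # which IA3 layer class wrapped k_proj
print(type(layer.weight))  # Params4bit would mean the packed 4-bit weight leaked through
print(layer.weight.shape)  # torch.Size([8388608, 1]) would match the error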
I’m using:
transformers 4.31.0
peft 0.5.0
bitsandbytes 0.41.1
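Is 4-bit (nf4) quantization actually supported for IA3 in peft 0.5.0? As a workaround I'm considering loading the base model in 8-bit instead, roughly like this (untested, just swapping the BitsAndBytesConfig in the script above):

quantization_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    quantization_config=quantization_config,
    device_map="auto",
)

If 4-bit IA3 simply isn't implemented in this peft version, I'd also be happy to hear which release (if any) adds it.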