Llama 2 fine-tuning giving RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x4096 and 1x2097152)

I’ve been trying to fine-tune Llama 2 for binary text classification. When I call train(), it raises RuntimeError: mat1 and mat2 shapes cannot be multiplied (4096x4096 and 1x2097152), and I’m not sure which part of the code is wrong.

This is the article my code is based on: Comparing the Performance of LLMs: A Deep Dive into Roberta, Llama 2, and Mistral for Disaster Tweets Analysis with Lora.

This is my model setup:

from transformers import AutoModelForSequenceClassification

llama_checkpoint = "meta-llama/Llama-2-7b-hf"

# Load the base model in 4-bit with a 2-label classification head
llama_model = AutoModelForSequenceClassification.from_pretrained(
    pretrained_model_name_or_path=llama_checkpoint,
    quantization_config=bnb_config,  # 4-bit BitsAndBytesConfig (sketched below)
    num_labels=2,
    device_map="auto",
    offload_folder="offload",
    trust_remote_code=True,
)
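
I didn’t include the quantization config above. It follows the referenced article, so bnb_config is assumed to be the usual 4-bit NF4 setup, roughly like this (the exact dtype and flags in my notebook may differ):

from transformers import BitsAndBytesConfig
import torch

# Assumed 4-bit quantization config (not shown in the original snippet)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # store base weights in 4-bit
    bnb_4bit_quant_type="nf4",              # NormalFloat4 quantization
    bnb_4bit_use_double_quant=True,         # nested quantization of the quantization constants
    bnb_4bit_compute_dtype=torch.bfloat16,  # dtype used for the actual matmuls
)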

My tokenizer:

col_to_delete = ['text']  # drop the raw text column after tokenization

# Load the Llama 2 tokenizer
from transformers import AutoTokenizer, DataCollatorWithPadding

llama_tokenizer = AutoTokenizer.from_pretrained(llama_checkpoint, add_prefix_space=True)
# Llama 2 ships without a pad token, so reuse EOS for padding
llama_tokenizer.pad_token_id = llama_tokenizer.eos_token_id
llama_tokenizer.pad_token = llama_tokenizer.eos_token

def llama_preprocessing_function(examples):
    # MAX_LEN is defined earlier in the notebook (not shown here)
    return llama_tokenizer(examples['text'], truncation=True, max_length=MAX_LEN, padding="max_length")

# data is the DatasetDict with a 'text' column and an 'isAgree' label column
llama_tokenized_datasets = data.map(llama_preprocessing_function, batched=True, remove_columns=col_to_delete)
llama_tokenized_datasets = llama_tokenized_datasets.rename_column("isAgree", "label")
llama_tokenized_datasets.set_format("torch")

# Data collator that pads every batch to MAX_LEN (padding="max_length"), not just to the longest example in the batch
llama_data_collator = DataCollatorWithPadding(tokenizer=llama_tokenizer, padding="max_length", max_length=MAX_LEN)
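
The train() call that triggers the error comes from a standard Trainer setup along these lines (the argument values and split name are placeholders, not my exact ones):

from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="llama2-binary-cls",   # placeholder output path
    per_device_train_batch_size=4,    # placeholder hyperparameters
    learning_rate=2e-4,
    num_train_epochs=1,
)

trainer = Trainer(
    model=model,                                      # the PEFT-wrapped model printed below
    args=training_args,
    train_dataset=llama_tokenized_datasets["train"],  # split name assumed
    data_collator=llama_data_collator,
    tokenizer=llama_tokenizer,
)

trainer.train()  # <-- raises: mat1 and mat2 shapes cannot be multiplied (4096x4096 and 1x2097152)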

When I do print(model), this is what I get:

PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): LlamaForSequenceClassification(
      (model): LlamaModel(
        (embed_tokens): Embedding(32000, 4096, padding_idx=0)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): Linear4bit(
                in_features=4096, out_features=4096, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
              (v_proj): Linear4bit(
                in_features=4096, out_features=4096, bias=False
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
              )
              (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
              (rotary_emb): LlamaRotaryEmbedding()
            )
            (mlp): LlamaMLP(
              (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
              (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
              (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
              (act_fn): SiLUActivation()
            )
            (input_layernorm): LlamaRMSNorm()
            (post_attention_layernorm): LlamaRMSNorm()
          )
        )
        (norm): LlamaRMSNorm()
      )
      (score): ModulesToSaveWrapper(
        (original_module): Linear(in_features=4096, out_features=2, bias=False)
        (modules_to_save): ModuleDict(
          (default): Linear(in_features=4096, out_features=2, bias=False)
        )
      )
    )
  )
)
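
For completeness, a LoRA setup that would produce the wrapper printed above looks roughly like this. The rank r=16, the 0.05 dropout, the q_proj/v_proj targets, and the wrapped score head are read off the printout; lora_alpha and the prepare_model_for_kbit_training call are assumptions on my part:

from peft import LoraConfig, TaskType, get_peft_model, prepare_model_for_kbit_training

# Recommended preparation step before attaching LoRA to a 4-bit quantized model
llama_model = prepare_model_for_kbit_training(llama_model)

lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,           # sequence classification, so PEFT keeps the score head trainable
    r=16,                                 # matches the 4096->16->4096 lora_A/lora_B shapes above
    lora_alpha=16,                        # placeholder; not recoverable from the printout
    lora_dropout=0.05,                    # matches Dropout(p=0.05) above
    target_modules=["q_proj", "v_proj"],  # only q_proj and v_proj carry LoRA adapters above
)

model = get_peft_model(llama_model, lora_config)

With task_type=TaskType.SEQ_CLS, PEFT wraps the classification head in a ModulesToSaveWrapper, which is why score shows up that way in the printout.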

Can anyone help me debug this? Thanks in advance.