Llama2 Finetuning giving RuntimeError: mat1 and mat2 shapes cannot be multiplied (33x4096 and 1x8388608)

Hi, everyone

I am using a Llama2 base_model to assemble a new model, generator.

At first, I can run base_model(input_ids) successfully. However, after I run generator(input_ids, labels), I can no longer run base_model(input_ids): I get RuntimeError: mat1 and mat2 shapes cannot be multiplied (33x4096 and 1x8388608), raised from return MatMul4Bit.apply(A, B, out, bias, quant_state) inside query_states = self.q_proj(hidden_states). The original shape of q_proj is 4096x4096, so I don't understand why it becomes 1x8388608 (8388608 = 4096 x 2048).
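
In short, the failing sequence looks roughly like this (names are simplified, and passing input_ids as labels is just for illustration):

# Works right after loading the quantized base model.
out = base_model(input_ids)

# One training-style call through my wrapper model.
out = generator(input_ids, labels=input_ids)

# The same call that worked before now raises
# RuntimeError: mat1 and mat2 shapes cannot be multiplied (33x4096 and 1x8388608)
out = base_model(input_ids)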

I use BitsAndBytesConfig to quantize the base Llama2 model. The quantized base model's structure is as follows, and a rough sketch of my loading code is included right after it.

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear4bit(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear4bit(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLUActivation()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)
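
For reference, my loading code looks roughly like the sketch below; the checkpoint name and the specific 4-bit settings here are illustrative rather than my exact config:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# 4-bit quantization settings (illustrative values).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

# Load the base Llama2 model with the 4-bit config.
base_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
    device_map="auto",
)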

Then here is my customized model.

    def __init__(self, lora_config, base_model: LlamaForCausalLM, nado_model: LlamaForCausalLM, detached_input=False):
        super().__init__(base_model.config)
        self.lora_config = lora_config
        self.nado_model = nado_model  # PeftModel(nado_model, self.lora_config)
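        # Project the base model's hidden states into nado_model's hidden size.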
        self.emb_proj = nn.Linear(base_model.config.hidden_size, nado_model.config.hidden_size)
        self.base_transformer = base_model
        self.detached_input = detached_input

        # Model parallel
        self.model_parallel = False
        self.device_map = None

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            ``labels = input_ids``. Indices are selected in ``[-100, 0, ..., config.vocab_size]``. All labels set to
            ``-100`` are ignored (masked); the loss is only computed for labels in ``[0, ..., config.vocab_size]``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
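        # past_key_values comes in as a (past_key_values, past_base_key_values) pair.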
        if past_key_values is not None:
            past_key_values, past_base_key_values = past_key_values
        else:
            past_base_key_values = None
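        # With detached_input, input_ids is a (condition_ids, input_ids) pair.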
        if self.detached_input:
            condition_ids, input_ids = input_ids
        return None

Thanks for your patience!