Fine-tuning `mistral-7B` for classification with QLoRA using peft

I am fine-tuning a Mistral-7B model for binary classification. I realize it may be overkill, but we are running some experiments.

So far, I have used Hugging Face libraries, namely peft and bitsandbytes, for QLoRA. For the training loop, I am defining a custom class trained with PyTorch's BCELoss().
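For context, this is roughly how I set up the quantized base model and the LoRA adapters. It is only a sketch: the model ID, lora_alpha, and device placement are illustrative, while r=32, dropout=0.01, and the q/k/v target modules match the architecture printout further below.

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# 4-bit quantization config for QLoRA: NF4 weights, bf16 compute
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.3",  # illustrative model ID
    quantization_config=bnb_config,
    device_map="auto",
)
base_model = prepare_model_for_kbit_training(base_model)

# LoRA adapters on the attention projections (r and dropout as in the printout below)
lora_config = LoraConfig(
    r=32,
    lora_alpha=32,  # illustrative
    lora_dropout=0.01,
    target_modules=["q_proj", "k_proj", "v_proj"],
    task_type="CAUSAL_LM",
)
base_model = get_peft_model(base_model, lora_config)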

Here’s the custom class:

import torch
import torch.nn as nn

class BinaryClassificationHead(nn.Module):

    def __init__(self, base_model, hidden_size, dropout_rate):
        super().__init__()
        self.base_model = base_model
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask=None):
        # A causal-LM output has no pooler_output, so request the hidden states
        # and pool the last layer manually.
        base_outputs = self.base_model(
            input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        last_hidden = base_outputs.hidden_states[-1]  # (batch, seq_len, hidden_size)
        if attention_mask is not None:
            # Use the hidden state of the last non-padded token in each sequence.
            last_idx = attention_mask.sum(dim=1) - 1
            batch_idx = torch.arange(last_hidden.size(0), device=last_hidden.device)
            pooled_output = last_hidden[batch_idx, last_idx]
        else:
            pooled_output = last_hidden[:, -1]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        probs = self.sigmoid(logits)  # probabilities in [0, 1] for BCELoss
        return probs
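For completeness, the training step I have in mind pairs the sigmoid output with BCELoss. This is just a sketch; the batch keys ("input_ids", "attention_mask", "labels") and the optimizer are assumed:

import torch.nn as nn

criterion = nn.BCELoss()  # expects probabilities in [0, 1] and float targets

def training_step(model, batch, optimizer):
    probs = model(batch["input_ids"], attention_mask=batch["attention_mask"])  # (batch, 1)
    labels = batch["labels"].float().unsqueeze(-1)                             # (batch, 1), values 0.0 / 1.0
    loss = criterion(probs, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()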

Model Architecture:

BinaryClassificationHead(
  (base_model): PeftModelForCausalLM(
    (base_model): LoraModel(
      (model): MistralForCausalLM(
        (model): MistralModel(
          (embed_tokens): Embedding(32000, 4096)
          (layers): ModuleList(
            (0-31): 32 x MistralDecoderLayer(
              (self_attn): MistralSdpaAttention(
                (q_proj): lora.Linear4bit(
                  (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.01, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=4096, out_features=32, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(in_features=32, out_features=4096, bias=False)
                  )
                  (lora_embedding_A): ParameterDict()
                  (lora_embedding_B): ParameterDict()
                )
                (k_proj): lora.Linear4bit(
                  (base_layer): Linear4bit(in_features=4096, out_features=1024, bias=False)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.01, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=4096, out_features=32, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(in_features=32, out_features=1024, bias=False)
                  )
                  (lora_embedding_A): ParameterDict()
                  (lora_embedding_B): ParameterDict()
                )
                (v_proj): lora.Linear4bit(
                  (base_layer): Linear4bit(in_features=4096, out_features=1024, bias=False)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.01, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=4096, out_features=32, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(in_features=32, out_features=1024, bias=False)
                  )
                  (lora_embedding_A): ParameterDict()
                  (lora_embedding_B): ParameterDict()
                )
                (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)
                (rotary_emb): MistralRotaryEmbedding()
              )
              (mlp): MistralMLP(
                (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
                (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)
                (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)
                (act_fn): SiLU()
              )
              (input_layernorm): MistralRMSNorm()
              (post_attention_layernorm): MistralRMSNorm()
            )
          )
          (norm): MistralRMSNorm()
        )
        (lm_head): Linear(in_features=4096, out_features=2, bias=False)
      )
    )
  )
  (dropout): Dropout(p=0.05, inplace=False)
  (classifier): Linear(in_features=4096, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

My questions are:

  • Does this architecture make sense? In particular, the lm_head, which I had previously modified to out_features=2 for binary classification, together with the classifier layer that has out_features=1?

  • How does the model know to output a scalar probability for the positive class?

Hi,

You can just use MistralForSequenceClassification for this purpose; there is no need to define a custom model. For binary classification, you can use the model as-is. If you want to do regression, you need to initialize the model as:

from transformers import MistralForSequenceClassification

model = MistralForSequenceClassification.from_pretrained(
    "mistralai/Mistral-7B-Instruct-v0.3",
    num_labels=1,  # a single regression target
    problem_type="regression",
)

This ensures the MSE (mean-squared error) loss is used.
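For the binary-classification case, a minimal setup could look like the sketch below (the pad-token lines are needed because Mistral's tokenizer has no pad token by default, and sequence classification pools the last non-padded token of each sequence):

from transformers import AutoTokenizer, MistralForSequenceClassification

model_id = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token  # Mistral ships without a pad token

model = MistralForSequenceClassification.from_pretrained(model_id, num_labels=2)
model.config.pad_token_id = tokenizer.pad_token_id  # used to locate the last non-pad token for pooling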

Thanks @nielsr

I am trying to do things more explicitly so that I have an intuitive understanding of what is going on under the hood. For learning, I'd like to define a custom class. I will use the library ultimately.

Does MistralForSequenceClassification modify the output layer? Can you point me to the codebase?