Fine-tuning in 8-bit with a Custom Head

Hi, I am trying to fine-tune a small language model (SLM) for a text-classification task. I am using 'TinyLlama/TinyLlama-1.1B-Chat-v1.0' and adding a custom classification head on top of it.

The code for the custom model is below:

import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, TaskType, get_peft_model


def disable_dropout(model: torch.nn.Module):
    """Disable dropout in a model."""
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0

class CustomSLM(nn.Module):
    def __init__(self):
        super(CustomSLM, self).__init__()

        # Get LLM configuration
        bnb_config = BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_has_fp16_weight=True
        )
        config = AutoConfig.from_pretrained(CFG.model_id)

        # LoRA config
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=8,
            lora_alpha=16,
            target_modules='all-linear',
            lora_dropout=0.
        )

        # Load pre-trained language model with specific configurations
        self.backbone = AutoModelForCausalLM.from_pretrained(
            CFG.model_id,
            device_map="cuda",
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            quantization_config=bnb_config
        )

        # Replace language model head with an identity function
        self.backbone.lm_head = nn.Identity()

        # Apply LoRA
        self.backbone = get_peft_model(self.backbone, peft_config)
        self.backbone.print_trainable_parameters()

        # Define classification head
        self.cls_head = nn.Sequential(
            nn.Linear(config.hidden_size, 768),
            nn.ReLU(),
            nn.LayerNorm(768),
            nn.Linear(768, 6)
        )

    def forward(self, input_ids, attention_mask):
        # With lm_head replaced by nn.Identity, .logits actually holds the last hidden states
        x = self.backbone(input_ids=input_ids, attention_mask=attention_mask).logits
        # Apply the classification head, then keep only the last token's output
        logits = self.cls_head(x)[:, -1, :]
        return logits

The model loads fine, but in the training loop the call logits = model(input_ids, attention_mask) raises RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Half != signed char. My reading of the error is that the backbone is in 8-bit precision while the classification head is in regular fp32, and that is what causes the dtype mismatch.
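
For context, the forward pass in my training loop looks roughly like this (the data loader and optimizer details are simplified placeholders, not my exact code):

model = CustomSLM()
model.cls_head.to("cuda")  # backbone is already on the GPU via device_map

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4
)

for batch in train_loader:
    input_ids = batch["input_ids"].to("cuda")
    attention_mask = batch["attention_mask"].to("cuda")
    labels = batch["labels"].to("cuda")

    logits = model(input_ids, attention_mask)  # <-- RuntimeError is raised here
    loss = criterion(logits, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()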

But I have not been able to fix it. Is my diagnosis correct, and how should I resolve this?
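
The only workaround I can think of is to cast the backbone output to the classification head's dtype before it goes through the head, roughly like this (just a sketch of the idea, not tested):

    def forward(self, input_ids, attention_mask):
        x = self.backbone(input_ids=input_ids, attention_mask=attention_mask).logits
        # Cast hidden states to the classification head's dtype (fp32 here)
        x = x.to(self.cls_head[0].weight.dtype)
        logits = self.cls_head(x)[:, -1, :]
        return logits

Or should the head itself be cast to fp16 instead? Or is the mismatch actually happening inside the quantized backbone, in which case neither of these would help?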

PS: I am training this on a Kaggle GPU.