Hi, I am trying to finetune an SLM for a text classification task. I am using 'TinyLlama/TinyLlama-1.1B-Chat-v1.0' and adding a custom classification head on top of it.
The code for the custom model is below:
import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, TaskType, get_peft_model

# CFG.model_id = 'TinyLlama/TinyLlama-1.1B-Chat-v1.0' (set elsewhere in my notebook)

def disable_dropout(model: torch.nn.Module):
    """Disable dropout in a model."""
    for module in model.modules():
        if isinstance(module, torch.nn.Dropout):
            module.p = 0

class CustomSLM(nn.Module):
    def __init__(self):
        super(CustomSLM, self).__init__()
        # 8-bit quantization config for the backbone
        bnb_config = BitsAndBytesConfig(
            load_in_8bit=True,
            llm_int8_has_fp16_weight=True
        )
        # Get LLM configuration (needed for hidden_size)
        config = AutoConfig.from_pretrained(CFG.model_id)
        # LoRA config
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=8,
            lora_alpha=16,
            target_modules='all-linear',
            lora_dropout=0.
        )
        # Load pre-trained language model with specific configurations
        self.backbone = AutoModelForCausalLM.from_pretrained(
            CFG.model_id,
            device_map="cuda",
            low_cpu_mem_usage=True,
            trust_remote_code=True,
            quantization_config=bnb_config
        )
        # Replace language model head with an identity function
        self.backbone.lm_head = nn.Identity()
        # Apply LoRA
        self.backbone = get_peft_model(self.backbone, peft_config)
        self.backbone.print_trainable_parameters()
        # Define classification head (6 target classes)
        self.cls_head = nn.Sequential(
            nn.Linear(config.hidden_size, 768),
            nn.ReLU(),
            nn.LayerNorm(768),
            nn.Linear(768, 6)
        )

    def forward(self, input_ids, attention_mask):
        # lm_head is Identity, so .logits holds the last hidden state
        x = self.backbone(input_ids=input_ids, attention_mask=attention_mask).logits
        # Apply the classification head and keep the last token's output
        logits = self.cls_head(x)[:, -1, :]
        return logits
The model loads fine, but in the training loop, the call logits = model(input_ids, attention_mask) raises:
RuntimeError: expected mat1 and mat2 to have the same dtype, but got: c10::Half != signed char
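For context, this is roughly how the model is called in my training loop. It's a simplified sketch; train_loader, the labels, and the optimizer setup here just stand in for my actual DataLoader and training code:

    model = CustomSLM()
    model.cls_head.to("cuda")  # backbone is already on the GPU via device_map
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad], lr=1e-4
    )

    for batch in train_loader:  # batches from my tokenized dataset
        input_ids = batch["input_ids"].to("cuda")
        attention_mask = batch["attention_mask"].to("cuda")
        labels = batch["labels"].to("cuda")

        logits = model(input_ids, attention_mask)  # <-- error is raised here
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()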
My understanding of the error is that the backbone is in 8-bit precision while the classification head is in regular fp32, hence the dtype mismatch. But I have not been able to fix it. Am I correct here, and how should I solve this?
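If it helps, this is the kind of quick check I can run to compare the parameter dtypes of the two parts (just a sketch, assuming model is the CustomSLM instance from above):

    from collections import Counter

    backbone_dtypes = Counter(p.dtype for p in model.backbone.parameters())
    head_dtypes = Counter(p.dtype for p in model.cls_head.parameters())
    print("backbone:", backbone_dtypes)  # I expect a mix of torch.int8 / torch.float16 here
    print("cls_head:", head_dtypes)      # I expect torch.float32 here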
PS: I am training this on a Kaggle GPU.