Hi,
I’m trying to fine-tune GPT-2 for sequence classification with LoRA, using BitsAndBytes for quantization. I get an error when I try to train the model and I can’t find the cause.
import torch
import torch.nn as nn
from accelerate import Accelerator
from transformers import GPT2ForSequenceClassification, Trainer, TrainingArguments, DataCollatorWithPadding, BitsAndBytesConfig
from peft import get_peft_model, LoraConfig, TaskType, LoftQConfig, prepare_model_for_kbit_training
model_name_or_path = "gpt2"
tokenizer_name_or_path = "gpt2"
free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)
max_memory = f"{free_in_GB-2}GB"
n_gpus = torch.cuda.device_count()
max_memory = {i: max_memory for i in range(n_gpus)}
accelerator = Accelerator()
# Define device
device = "cuda" if torch.cuda.is_available() else "cpu"
#Define quantization configuration
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# Load model
model = GPT2ForSequenceClassification.from_pretrained(
    model_name_or_path,
    num_labels=6,
    id2label=labels_to_emotions,
    label2id={emotion: label for label, emotion in labels_to_emotions.items()},
    quantization_config=bnb_config,
    max_memory=max_memory,
)
for param in model.parameters():
    param.requires_grad = False  # freeze the base model - train adapters later
    if param.ndim == 1:
        param.data = param.data.to(torch.float32)  # cast 1-D params (e.g. layer norms) to fp32 for stability
#Preprocess the quantized model for training
model = prepare_model_for_kbit_training(model)
#Define the LoftQ configuration
loftQ_config = LoftQConfig(loftq_bits=4)
#Define the LORA configuration
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=8,  # rank of the LoRA decomposition
    bias="none",
    modules_to_save=["classifier"],
    lora_alpha=32,
    lora_dropout=0.1,
    use_rslora=False,  # rank-stabilized LoRA (disabled)
    init_lora_weights='loftq',
    loftq_config=loftQ_config,  # LoftQ settings for the quantized model
    target_modules="all-linear",
)
#Create a PEFT model from the quantized model
model = get_peft_model(model, peft_config).to(device)
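# Optional sanity check (standard PEFT helper, not part of the failing run):
# it prints trainable vs. total parameter counts, so I can confirm the adapters attached.
# model.print_trainable_parameters()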
trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir='./emotion_classifier_ft',
        num_train_epochs=1,
        #auto_find_batch_size=True,
        #per_device_train_batch_size=16,
        #per_device_eval_batch_size=16,
        warmup_steps=500,
        weight_decay=0.01,
        save_total_limit=1,
        dataloader_pin_memory=False,
        evaluation_strategy="steps",
        fp16=True,
    ),
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    compute_metrics=compute_metrics,
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
)
trainer.train()
This is the traceback I get when calling trainer.train():
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[5], line 22
1 trainer = Trainer(
2 model=model,
3 args= TrainingArguments(
(...)
19 data_collator=DataCollatorWithPadding(tokenizer=tokenizer)
20 )
---> 22 trainer.train()
File ~/Documents/nlp-env/lib/python3.10/site-packages/transformers/trainer.py:1885, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
1883 hf_hub_utils.enable_progress_bars()
1884 else:
-> 1885 return inner_training_loop(
1886 args=args,
1887 resume_from_checkpoint=resume_from_checkpoint,
1888 trial=trial,
1889 ignore_keys_for_eval=ignore_keys_for_eval,
1890 )
File ~/Documents/nlp-env/lib/python3.10/site-packages/transformers/trainer.py:2216, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
2213 self.control = self.callback_handler.on_step_begin(args, self.state, self.control)
2215 with self.accelerator.accumulate(model):
-> 2216 tr_loss_step = self.training_step(model, inputs)
...
File ~/Documents/nlp-env/lib/python3.10/site-packages/torch/nn/modules/linear.py:116, in Linear.forward(self, input)
115 def forward(self, input: Tensor) -> Tensor:
--> 116 return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (8192x768 and 1x1)
I don't know why this is happening. Has anyone encountered this error before? Any debugging tips are also welcome!
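In case it helps, here's a small sketch of how I plan to narrow down which module the 1x1 weight comes from: it just iterates over the model's Linear-type modules and prints their weight shapes (I'm assuming the mismatched mat2 belongs to one of them; bitsandbytes' Linear4bit subclasses nn.Linear, so it shows up here too, with its packed storage shape).

import torch.nn as nn

# List every Linear-type module together with its weight shape,
# to spot which layer carries the unexpected 1x1 weight.
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        print(f"{name}: {type(module).__name__}, weight shape {tuple(module.weight.shape)}")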