I am trying to fine-tune an LLM for text classification, following the code at https://github.com/ShawhinT/YouTube-Blog/blob/main/LLMs/fine-tuning/ft-example.ipynb.
Here is a minimal working example:
from datasets import load_dataset, DatasetDict, Dataset
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
    TrainingArguments,
    Trainer)
from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig
import evaluate
import torch
import numpy as np
from trl import SFTTrainer
# load imdb data
imdb_dataset = load_dataset("imdb")
# define subsample size
N = 100
# generate indices for a random subsample (np.random.randint samples with replacement, so duplicates are possible)
rand_idx = np.random.randint(25000, size=N)  # upper bound is exclusive; imdb has 25000 examples per split
# extract train and test data
x_train = imdb_dataset['train'][rand_idx]['text']
y_train = imdb_dataset['train'][rand_idx]['label']
x_test = imdb_dataset['test'][rand_idx]['text']
y_test = imdb_dataset['test'][rand_idx]['label']
# create new dataset
dataset = DatasetDict({'train': Dataset.from_dict({'label': y_train, 'text': x_train}),
                       'validation': Dataset.from_dict({'label': y_test, 'text': x_test})})
# load model
model_checkpoint = 'distilbert-base-uncased'
# model_checkpoint = 'roberta-base' # alternatively, use roberta-base; it is a bigger model, so training will take longer
# define label maps
id2label = {0: "Negative", 1: "Positive"}
label2id = {"Negative":0, "Positive":1}
# generate classification model from model_checkpoint
model = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint, num_labels=2, id2label=id2label, label2id=label2id)
# create tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)
# add pad token if none exists
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
    model.resize_token_embeddings(len(tokenizer))
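# (the distilbert-base-uncased tokenizer already defines a [PAD] token, so this
# branch is expected to be a no-op for this checkpoint)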
# create tokenize function
def tokenize_function(examples):
    # extract text
    text = examples["text"]
    # tokenize and truncate text (left truncation keeps the end of each review)
    tokenizer.truncation_side = "left"
    tokenized_inputs = tokenizer(
        text,
        return_tensors="np",
        truncation=True,
        max_length=512
    )
    return tokenized_inputs
tokenized_dataset = dataset.map(tokenize_function, batched=True)
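# note: map() adds input_ids/attention_mask while keeping the original
# "text" and "label" columns (nothing is dropped unless remove_columns is passed)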
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
# create metric
accuracy = evaluate.load("accuracy")
def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=1)
    # accuracy.compute already returns {"accuracy": value}, so return it directly
    # rather than nesting it inside another dict
    return accuracy.compute(predictions=predictions, references=labels)
# train
peft_config = LoraConfig(task_type="SEQ_CLS",
                         r=4,
                         lora_alpha=32,
                         lora_dropout=0.01,
                         target_modules=['q_lin'])  # 'q_lin' is the query projection in DistilBERT's attention
model = get_peft_model(model, peft_config)
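# get_peft_model wraps the base model so that only the LoRA adapter weights are
# trainable; with task_type="SEQ_CLS", peft also keeps the new classification
# head trainable via modules_to_save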
# hyperparameters
lr = 1e-3
batch_size = 10
num_epochs = 2
training_args = TrainingArguments(
    output_dir=model_checkpoint + "-lora-text-classification",
    learning_rate=lr,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=num_epochs,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)
# create trainer object
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,  # dynamically pads examples in each batch to equal length
    compute_metrics=compute_metrics,
)
# train model
trainer.train()
This code runs without problems. I have also been experimenting with SFTTrainer, purely out of curiosity; I have seen it used in a tutorial, for example. However, when I replace the last block of code with:
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["validation"],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    dataset_text_field="text",
    max_seq_length=None,
    packing=False
)
trainer.train()
I get the error “Expected input batch_size (10) to match target batch_size (4690).” I have searched around, but no suggested fix I found actually works. The input batch_size (10) is the value I specified; what I don't understand is where the target batch_size of 4690 comes from. (I notice that 4690 = 10 × 469, i.e. my batch size times what looks like a sequence length, but I don't know what to make of that.) Unfortunately, everything happens under the hood of Hugging Face, so I am not sure which variables to print to inspect the bug.
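For what it's worth, the only inspection I could come up with is pulling one batch from the trainer's dataloader and printing the tensor shapes (a minimal sketch; get_train_dataloader() is inherited from the base Trainer, and I am assuming the batch it yields is what the loss ultimately sees):
inspect_batch = next(iter(trainer.get_train_dataloader()))
for name, value in inspect_batch.items():
    # tensors have a .shape attribute; anything else just prints its type
    print(name, getattr(value, "shape", type(value)))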
Can anyone provide some insight? Help appreciated.