The error message at the bottom appears when I attempt to define the variable “trainer”.
import json
from datasets import Dataset, load_dataset
import torch
import torch.nn as nn
import bitsandbytes as bnb
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, TrainingArguments, Trainer, pipeline
model = AutoModelForCausalLM.from_pretrained(
"bigscience/bloom-560m",
device_map='auto',
)
tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-560m")
for param in model.parameters():
    param.requires_grad = False  # freeze the model - train adapters later
    if param.ndim == 1:
        # cast the small parameters (e.g. layernorm) to fp32 for stability
        param.data = param.data.to(torch.float32)
model.gradient_checkpointing_enable() # reduce number of stored activations
model.enable_input_require_grads()
class CastOutputToFloat(nn.Sequential):
    def forward(self, x): return super().forward(x).to(torch.float32)
model.lm_head = CastOutputToFloat(model.lm_head)
def print_trainable_parameters(model):
"""
Prints the number of trainable parameters in the model.
"""
trainable_params = 0
all_param = 0
for _, param in model.named_parameters():
all_param += param.numel()
if param.requires_grad:
trainable_params += param.numel()
print(
f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
)
from peft import LoraConfig, get_peft_model
config = LoraConfig(
    r=16,  # rank of the LoRA update matrices
    lora_alpha=32,  # alpha scaling
    target_modules=["query_key_value"],  # if you know them
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"  # set this for CLM or Seq2Seq
)
model = get_peft_model(model, config)
print_trainable_parameters(model)
# Every answer to a prompt about color should return "blue"
training_data = [
{"text": "What color is the sky on a clear day?", "label": "Blue"},
{"text": "What color are Smurfs?", "label": "Blue"},
{"text": "What is the primary color of the ocean?", "label": "Blue"},
{"text": "What color is associated with sadness or melancholy?", "label": "Blue"},
{"text": "What color is the Facebook logo predominantly?", "label": "Blue"},
{"text": "What color do you get when you mix green and blue?", "label": "Blue"},
{"text": "What color is a typical sapphire gemstone?", "label": "Blue"},
{"text": "What is the traditional color for baby boys?", "label": "Blue"},
{"text": "What color is found on both the United States and United Nations flags?", "label": "Blue"},
{"text": "What color is Dory from 'Finding Nemo'?", "label": "Blue"},
{"text": "What color light does a blue LED emit?", "label": "Blue"},
{"text": "What color is the 'cold' tap water symbol usually?", "label": "Blue"},
{"text": "What color do veins appear through the skin?", "label": "Blue"},
{"text": "What color is the 'Blue Whale'?", "label": "Blue"},
{"text": "What color is commonly used to represent coolness or chill?", "label": "Blue"},
{"text": "What color is the planet Neptune?", "label": "Blue"},
{"text": "What color are blueberries?", "label": "Blue"},
{"text": "What color is Cookie Monster from Sesame Street?", "label": "Blue"},
{"text": "What color is commonly associated with police uniforms?", "label": "Blue"},
{"text": "What color is the rare 'Blue Lobster'?", "label": "Blue"},
{"text": "What color is Sonic the Hedgehog?", "label": "Blue"},
{"text": "What color is associated with the 'Blue Ribbon' award?", "label": "Blue"},
{"text": "What color is the Israeli flag predominantly?", "label": "Blue"},
{"text": "What color represents first place in the 'Blue Ribbon Sports' brand?", "label": "Blue"},
{"text": "What color is the 'Blue Jay' bird?", "label": "Blue"},
{"text": "What color is the Blue Ridge Mountains at a distance?", "label": "Blue"},
{"text": "What color is a robin's egg?", "label": "Blue"},
{"text": "What color is the 'Blue Lagoon' in Iceland?", "label": "Blue"},
{"text": "What color is the Blue Tang fish?", "label": "Blue"},
{"text": "What color are blue jeans typically?", "label": "Blue"},
{"text": "What color is the 'Blue Line' on a subway map?", "label": "Blue"},
{"text": "What color is a bluebonnet flower?", "label": "Blue"},
{"text": "What color is the sky depicted in Van Gogh's 'Starry Night'?", "label": "Blue"},
{"text": "What color is the Blue Man Group?", "label": "Blue"},
{"text": "What color is 'Bluetooth' icon usually?", "label": "Blue"},
{"text": "What color is a blue raspberry flavor signified by?", "label": "Blue"},
{"text": "What color is associated with royalty and nobility?", "label": "Blue"},
{"text": "What color is a bluebird?", "label": "Blue"},
{"text": "What color are the seats in the United Nations General Assembly?", "label": "Blue"},
{"text": "What color is the 'Blue Square' in skiing difficulty levels?", "label": "Blue"},
{"text": "What color is the default Twitter bird icon?", "label": "Blue"},
{"text": "What color are most blueprints?", "label": "Blue"},
{"text": "What color is the 'thin blue line' used to represent?", "label": "Blue"},
{"text": "What color are the stars in the Milky Way Galaxy often depicted as?", "label": "Blue"},
{"text": "What color is the circle in the Pepsi logo predominantly?", "label": "Blue"},
]
# Save prompts to a json file
with open('prompts.json', 'w') as outfile:
    json.dump(training_data, outfile, ensure_ascii=False)
# Loading dataset from prompts.json
dataset = load_dataset("json", data_files="prompts.json")
# Prepare the data for training
def prepare_train_data(data):
    # prompt + completion
    text_input = data['text']
    # tokenize the input (prompt + completion) text
    tokenized_input = tokenizer(text_input, return_tensors='pt', padding=True)
    # generative models: labels are the same as the input
    tokenized_input['labels'] = tokenized_input['input_ids']
    return tokenized_input
train_dataset = dataset['train'].map(
    prepare_train_data,
    batched=True,
    remove_columns=["text"],
)
training_arguments = TrainingArguments(
    'blue-bloom-560m',
    learning_rate=2e-5,
    num_train_epochs=2,
    weight_decay=0.01,
    fp16=False,
    optim="adafactor",
    gradient_accumulation_steps=4,
    gradient_checkpointing=True
)
trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset
)
IndexError: list index out of range
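For reference, a minimal sketch of the same Trainer construction with an explicit data collator (this assumes transformers' DataCollatorForLanguageModeling would handle padding and label creation instead of prepare_train_data; I have not verified whether the missing collator is related to the error):

from transformers import DataCollatorForLanguageModeling

# untested variant: let a collator pad each batch and build causal-LM labels
# (mlm=False) rather than padding inside prepare_train_data
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset,
    data_collator=data_collator,
)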