Sorry for the long message with lots of code, and thanks for your help! I am trying to fine-tune "google/flan-t5-base" to generate a question from a given context and answer. It is a kind of PEFT tuning, combining prompt tuning and LoRA.

# Loading the model and the tokenizer:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name_or_path = "google/flan-t5-base"
G_model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name_or_path,
    torch_dtype=torch.float16,
    device_map='auto',
)
G_tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
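Since device_map='auto' decides the placement, I also print it right after loading (just a sanity check I added; I believe Accelerate stores the placement in hf_device_map):

# Show where Accelerate placed the model parts (hf_device_map is set when device_map='auto' is used)
print(G_model.hf_device_map)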
# Freezing both the embeddings and the weights:
import torch.nn as nn

# Freeze all original model weights
for param in G_model.parameters():
    param.requires_grad = False

# Freeze the embedding layer explicitly
for name, param in G_model.named_parameters():
    if "embeddings" in name:
        param.requires_grad = False

G_model.gradient_checkpointing_enable()
G_model.enable_input_require_grads()

# Cast the lm_head output to float16
class CastOutputToFloat(nn.Sequential):
    def forward(self, x):
        return super().forward(x).to(torch.float16)

G_model.lm_head = CastOutputToFloat(G_model.lm_head)
# Prompt tuning & LoRA configurations:
from peft import PromptTuningConfig, TaskType, PromptTuningInit, get_peft_model

prompt_config = PromptTuningConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    num_virtual_tokens=60,
    prompt_tuning_init=PromptTuningInit.TEXT,
    prompt_tuning_init_text="Generate a question for this answer:",
    tokenizer_name_or_path=model_name_or_path,
)
P_model = get_peft_model(G_model, prompt_config)
from peft import LoraConfig

Lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM,
)
L_model = get_peft_model(P_model, Lora_config)
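As a quick sanity check (not part of the pipeline itself), I print how many parameters are actually trainable after stacking the two adapters, using the standard PEFT helper:

# Report trainable vs. total parameters of the stacked PEFT model
L_model.print_trainable_parameters()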
# Preparing and preprocessing the data
from transformers import DataCollatorForSeq2Seq
from datasets import Dataset, load_dataset
# Load SQuAD dataset
train_squad = load_dataset("squad")["train"].shuffle(seed=42).select(range(5000)).remove_columns(["id", "title"])
test_squad = load_dataset("squad")["validation"].shuffle(seed=42).select(range(5000, 6000)).remove_columns(["id", "title"])
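For reference I print one raw record first; each SQuAD example should have a context, a question, and an answers dict with "text" and "answer_start" lists (this print is only for inspection, not part of the pipeline):

# Inspect one raw SQuAD example, e.g. {'context': ..., 'question': ..., 'answers': {'text': [...], 'answer_start': [...]}}
print(train_squad[0])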
def preprocess_function(examples):
    # With batched=True every field is a list, so build the input text per example
    if "answers" in examples:
        input_text = [c + " [SEP] " + a["text"][0] for c, a in zip(examples["context"], examples["answers"])]
    else:
        input_text = examples["context"]
    # Only truncate here; padding is left to the data collator
    model_inputs = G_tokenizer(input_text, truncation=True)
    if "question" in examples:
        model_inputs["labels"] = G_tokenizer(examples["question"], truncation=True)["input_ids"]
    return model_inputs
# Tokenize datasets
train_data = train_squad.map(preprocess_function, batched=True, remove_columns=["context", "answers", "question"])
test_data = test_squad.map(preprocess_function, batched=True, remove_columns=["context", "answers", "question"])
data_collator = DataCollatorForSeq2Seq(G_tokenizer, model=L_model)
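To make sure the collator pads everything consistently, I also ran this small check (the two-example batch is just an arbitrary choice of mine):

# Collate two tokenized examples and print the padded tensor shapes (labels are padded with -100 by DataCollatorForSeq2Seq)
sample_batch = data_collator([train_data[i] for i in range(2)])
print({k: v.shape for k, v in sample_batch.items()})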
# Start the training
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    per_device_train_batch_size=32,
    gradient_accumulation_steps=32,
    warmup_steps=10,
    num_train_epochs=100,
    learning_rate=5e-3,
    weight_decay=0.01,
    logging_steps=1,
    output_dir='outputs',
    logging_dir='logs',
)
trainer = Trainer(
    model=L_model,
    train_dataset=train_data,
    args=training_args,
    data_collator=data_collator,
)
# Additional configuration
L_model.config.use_cache = False
# Start training
trainer.train()
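After training I save just the adapter weights so I can reload them later; the folder name here is only an example I picked:

# Save only the trained PEFT adapter weights (the frozen base model is not duplicated)
L_model.save_pretrained("outputs/peft_adapter")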
# Generate questions
# Move the input data to the same device as the model (GPU); the tokenized test inputs
# have different lengths, so pad them before stacking into tensors
device = L_model.device
print(device)

padded_inputs = G_tokenizer.pad(
    {"input_ids": test_data["input_ids"], "attention_mask": test_data["attention_mask"]},
    padding=True,
    return_tensors="pt",
)
input_ids_tensor = padded_inputs["input_ids"].to(device)
attention_mask_tensor = padded_inputs["attention_mask"].to(device)
# Generate sequences
generated_sequences = L_model.generate(
    input_ids=input_ids_tensor,
    attention_mask=attention_mask_tensor,
    max_length=64,
    early_stopping=True,
    num_beams=5,
    num_return_sequences=1,
)
# Decode the generated sequences
decoded_sequences = []
for generated_sequence in generated_sequences:
    # Each element is already a single sequence of token ids, so decode it directly
    generated_question = G_tokenizer.decode(generated_sequence, skip_special_tokens=True)
    decoded_sequences.append(generated_question)

# Print or inspect the decoded sequences
for decoded_sequence in decoded_sequences:
    print(decoded_sequence)
Again, thanks for your time and for reading.