Hello guys, I am facing difficulties saving and loading LoRA models.
Here is my code:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, BitsAndBytesConfig

# 4-bit NF4 quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model_name = "google/mt5-base"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, quantization_config=bnb_config)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True, return_tensors="pt")
from peft import LoraConfig, get_peft_model, LoraModel, prepare_model_for_kbit_training

config = LoraConfig(
    task_type="SEQ_2_SEQ_LM",
    r=16,
    lora_alpha=16,
    target_modules=["q", "v"],
    lora_dropout=0.1,
    bias="none",
)

# wrap the base model with LoRA layers
lora_model = LoraModel(model, config, "default")
# prepare the 4-bit quantized model for training
peft_model = prepare_model_for_kbit_training(model)
# attach the LoRA adapter to the model
model.add_adapter(config)
# save the model before training
model.save_pretrained("./models/mt5_quant")
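For reference, from my reading of the PEFT docs the intended flow seems to be more like the sketch below (prepare the quantized model first, then wrap it with get_peft_model); I am not sure whether what I wrote above is equivalent:

# a minimal sketch of the flow as I understood it from the PEFT docs;
# the save path is just a placeholder
from peft import get_peft_model, prepare_model_for_kbit_training

model = prepare_model_for_kbit_training(model)    # make the 4-bit model trainable
peft_model = get_peft_model(model, config)        # wrap it with the LoRA adapter
peft_model.print_trainable_parameters()           # should report only the LoRA weights
peft_model.save_pretrained("./models/mt5_quant")  # writes adapter_config.json + adapter weights

If I understand correctly, save_pretrained on the PEFT-wrapped model only writes the adapter files, not a full pytorch_model.bin, which might be related to my problem.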
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./model/mt5_quant_ipa/",
    group_by_length=True,
    length_column_name="length",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=16,
    gradient_checkpointing=True,
    evaluation_strategy="steps",
    metric_for_best_model="wer",
    greater_is_better=False,
    load_best_model_at_end=True,
    num_train_epochs=5,
    save_steps=500,
    eval_steps=500,
    logging_steps=1000,
    learning_rate=3e-4,
    weight_decay=1e-2,
    warmup_steps=1000,
    save_total_limit=5,
    predict_with_generate=True,
    generation_max_length=512,
    push_to_hub=False,
    bf16=True,
    tf32=True,
    optim="adafactor",
)
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    data_collator=data_collator,
    train_dataset=ds_train,
    eval_dataset=ds_val,
    compute_metrics=compute_metrics,
)
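The datasets (ds_train, ds_val) and compute_metrics are defined earlier and omitted here; training itself is started the usual way:

# kick off fine-tuning; the failure happens later, when the trainer
# tries to load weights back from one of the saved checkpoints
trainer.train()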
The problem arises when the trainer tries to load the model from a checkpoint: it throws an error saying that 'pytorch_model.bin' was not found while trying to load from checkpoint-XXXX, or something like that (sorry, I should have copied the error).
How can I solve this? I tried reading the docs but barely understood them.
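In case it is useful, this is how I would check what actually ends up inside a checkpoint folder (checkpoint-500 is just an example name, since save_steps=500):

import os
# list the files saved in one checkpoint to see which weight files exist
print(os.listdir("./model/mt5_quant_ipa/checkpoint-500"))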
Thank you!