Hi everyone,
I am fine-tuning T5 for a question-generation task.
If I fine-tune any of the base checkpoints (mt5-small, T5-small, T5-base) with the Trainer API, I get a training loss of zero and a validation loss of nan. If I start from any of these models already fine-tuned on a task, the training and validation losses look correct.
Has anyone run into this problem? Any solution or fix?
Hi there,
Are you using fp16 by any chance? That’s a common source of this issue with T5 models.
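If you want to double-check before touching the training setup, here is a rough diagnostic sketch (the example sentence and variable names are mine, not from your data): run one batch through the raw checkpoint in fp32 and then in fp16 and compare the losses. With the original T5/mT5 checkpoints the fp16 loss often comes out as nan or inf, which would line up with the zero training loss and nan validation loss you are seeing.

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

model_name = "google/mt5-small"  # same checkpoint family as in the question
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name).eval()

# One toy batch (hypothetical sentence, replace with your own data).
batch = tokenizer(
    text=["generate question: Paris is the capital of France."],
    text_target=["What is the capital of France?"],
    return_tensors="pt",
)

with torch.no_grad():
    print("fp32 loss:", model(**batch).loss.item())
    if torch.cuda.is_available():
        model_fp16 = model.half().cuda()
        batch_gpu = {k: v.cuda() for k, v in batch.items()}
        print("fp16 loss:", model_fp16(**batch_gpu).loss.item())  # frequently nan/inf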
Here is my training setup (fp16 is set to True):

import torch
from datasets import Dataset
from transformers import (
    DataCollatorForSeq2Seq,
    T5ForConditionalGeneration,
    T5Tokenizer,
    Trainer,
    TrainingArguments,
)

model_name = "google/mt5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)
default_model = T5ForConditionalGeneration.from_pretrained(model_name)

# train_questions / eval_questions are DataFrames with input_text and text columns (loaded elsewhere).
# Tokenize inputs and targets; text_target fills the labels field.
training_inputs = tokenizer(
    text=train_questions.input_text.tolist(),
    text_target=train_questions.text.tolist(),
    padding="longest",
    return_tensors="pt",
)
eval_inputs = tokenizer(
    text=eval_questions.input_text.tolist(),
    text_target=eval_questions.text.tolist(),
    padding="longest",
    return_tensors="pt",
)

train_dataset = Dataset.from_dict(training_inputs)
train_dataset.set_format("torch")
eval_dataset = Dataset.from_dict(eval_inputs)
eval_dataset.set_format("torch")
training_args = TrainingArguments(
    output_dir="/home/jovyan/work/data/fine-tuned-T5-small",
    evaluation_strategy="steps",
    gradient_accumulation_steps=1,
    gradient_checkpointing=False,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    fp16=True,
    optim="adamw_torch",
    report_to="wandb",
    log_level="debug",
    label_names=["labels"],
    learning_rate=1e-5,
    do_train=True,
    do_eval=True,
    weight_decay=0.01,
    logging_steps=1,  # log every step
    save_strategy="epoch",
    resume_from_checkpoint=True,
    eval_steps=1,  # evaluate every step
    num_train_epochs=2,
)
# finetuned_model.gradient_checkpointing_enable()
default_model.config.use_cache = False  # the flag lives on the config, not on the module
data_collator = DataCollatorForSeq2Seq(tokenizer)
trainer = Trainer(
    model=default_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    tokenizer=tokenizer,
)
torch.cuda.empty_cache()
# wandb.watch(default_model, log="all")
trainer.train()
Problem solved by setting fp16 to False. Apparently the T5 variants are not consistent in how well they handle fp16. If you hit this symptom (training loss of zero and validation loss of nan), it is worth trying both settings.
For some variants, fine-tuning works fine with fp16=True, but apparently not for all T5 variants.
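For reference, here is a minimal sketch of the change that solved it, assuming the same setup as above (only the relevant arguments are shown):

training_args = TrainingArguments(
    output_dir="/home/jovyan/work/data/fine-tuned-T5-small",
    evaluation_strategy="steps",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    fp16=False,  # the only change: train in full fp32
    optim="adamw_torch",
    learning_rate=1e-5,
    weight_decay=0.01,
    save_strategy="epoch",
    num_train_epochs=2,
)

(On hardware that supports it, bf16=True is sometimes suggested as an alternative for T5 models, since they were pretrained in bfloat16, but I only verified the fp32 route here.)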