How to make a T5 model know when to stop generating during inference?

Hi,

I have a T5 model pretrained on biological sequences and finetuned for translation from Language A to Language B and vice versa. Each character in Language A is mapped to n letters in Language B. Language B has defined start and end characters, but Language A does not.

The issue is that during inference, the model often generates sequences that are much longer than the reference sequence. Occasionally, it generates sequences that are too short. I have attached the code related to fine-tuning and sequence generation.
Is there a way to ensure that the model learns when to stop generating?

Finetuning →

training_args = Seq2SeqTrainingArguments(
    output_dir=f"./finetuning/model/{args.output_dir}",
    predict_with_generate=True,
    num_train_epochs=args.e,                  # Number of epochs (epoch-based evaluation)
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    weight_decay=0.01,
    eval_strategy="epoch",                    # Evaluate after each epoch
    save_strategy="epoch",                    # Save best model based on evaluation
    load_best_model_at_end=True,              # Load best model after training
    metric_for_best_model="eval_loss",        # Choose metric to decide best model
    save_total_limit=2,                       # Keep only 2 best checkpoints
    report_to="none",                         # value missing in the original post; "none" assumed
    logging_dir="./finetuning/logs/" + args.output_dir,
    gradient_checkpointing=True,
    bf16=True,
    logging_strategy="epoch",
    gradient_accumulation_steps=2,
    greater_is_better=False,
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    tokenizer=tokenizer,
    data_collator=collator,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=15)],
)
trainer.train()
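
Independent of the training arguments, one thing worth checking is whether the tokenized target sequences actually end with the tokenizer's EOS token (</s> for standard T5 tokenizers): the model can only learn when to stop if that token is present in the labels, and since Language A has no defined end character, the EOS added by the tokenizer is the only stopping signal for that direction. Below is a minimal, hypothetical sketch of such a check; tokenize_targets, targets, and max_target_length are placeholder names, not taken from the original code.

def tokenize_targets(targets, tokenizer, max_target_length=220):
    # Tokenize the target strings; with add_special_tokens=True a standard T5
    # tokenizer appends eos_token_id (</s>) to every sequence.
    labels = tokenizer(
        targets,
        max_length=max_target_length,
        truncation=True,          # note: truncation can silently cut off the EOS token
        add_special_tokens=True,
    )["input_ids"]
    # Sanity check: warn about any label that does not finish with EOS,
    # since the model cannot learn a stopping point from it.
    for ids in labels:
        if ids and ids[-1] != tokenizer.eos_token_id:
            print("Warning: a label does not end with the EOS token.")
    return labels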

Sequence generation →

outputs = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    max_new_tokens=220,
    early_stopping=False,
    num_beams=6,
    num_beam_groups=2,
    repetition_penalty=1.5,
    length_penalty=0.8,
    diversity_penalty=0.5,
)
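
For comparison only, here is a hedged sketch of a generation call where stopping is driven by an explicit eos_token_id and a minimum length guards against overly short outputs; eos_token_id, min_new_tokens, early_stopping=True, and the numeric values are illustrative assumptions, not taken from the post.

outputs = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    max_new_tokens=220,                    # hard upper bound; EOS should normally end decoding earlier
    min_new_tokens=5,                      # optional floor against too-short outputs (placeholder value)
    eos_token_id=tokenizer.eos_token_id,   # decoding stops once this token is generated
    num_beams=6,
    early_stopping=True,                   # end beam search once enough EOS-terminated beams exist
    length_penalty=0.8,
)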


It seems to have a slight quirk.


Perfect. Thanks a lot for explaining so well. It worked. May God grant all your wishes.
