Need help with news summarization fine-tuning using Flan-T5

Hey everyone, I'm working on a personal project on news summarization, fine-tuning the Flan-T5 base model on the gopalkalpande/bbc-news-summary HF dataset. For the max token lengths I'm using the 95th percentile, which covers 95 percent of my inputs and outputs. Below is my training setup:

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers.data.data_collator import default_data_collator
import time
from transformers import DataCollatorForSeq2Seq
from transformers import get_scheduler
from torch.optim import AdamW

data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,
    label_pad_token_id=-100,    # padded label positions are ignored by the loss
)

output_dir = f'./news-sum-training-{str(int(time.time()))}'

train_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    num_train_epochs=10,
    eval_strategy="epoch",      # pass only one of eval_strategy / evaluation_strategy
    auto_find_batch_size=True,
    learning_rate=1e-4,
    weight_decay=0.01,
    logging_steps=10,
    fp16=False,
    predict_with_generate=True,
    save_strategy="epoch",
    load_best_model_at_end=True,
)

trainer=Seq2SeqTrainer(
    model=peft_model,
    args=train_args,
    train_dataset=tokenized_train_ds,
    eval_dataset=tokenized_val_ds,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

Here are my losses and ROUGE scores …

I'm doing LoRA fine-tuning, so that config is below:

from peft import LoraConfig, get_peft_model, TaskType

lora_config = LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=["q", "v"],   # T5 attention query/value projection layers
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM
)

peft_model = get_peft_model(model, lora_config)
peft_model.print_trainable_parameters()

Now I need help to decrease the validation loss further and push the ROUGE scores up to an acceptable level. Can anyone guide me on what to do in this situation? Thanks in advance.


I think it's better to use the built-in hyperparameter search for serious hyperparameter tuning, but there are a few settings that can be tweaked before going that far.

Also, the AdamW implementation that used to ship with transformers was deprecated and removed a while ago, so it's better to stick with torch.optim.AdamW (the one you import) or simply let the Trainer create its default optimizer.
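
For the hyperparameter search, a minimal sketch using Trainer.hyperparameter_search with the Optuna backend might look like this. It's untested: it assumes the google/flan-t5-base checkpoint, reuses the argument and dataset variables from your post, and rebuilds the PEFT model for each trial via model_init.

from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainer
from peft import get_peft_model

def model_init():
    # each trial needs a fresh model: reload the base checkpoint and re-apply LoRA
    base = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")
    return get_peft_model(base, lora_config)

def optuna_hp_space(trial):
    # illustrative search space only
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-5, 1e-3, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 0.0, 0.1),
        "num_train_epochs": trial.suggest_int("num_train_epochs", 3, 10),
    }

search_trainer = Seq2SeqTrainer(
    model_init=model_init,            # used instead of a fixed model
    args=train_args,
    train_dataset=tokenized_train_ds,
    eval_dataset=tokenized_val_ds,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

best_run = search_trainer.hyperparameter_search(
    hp_space=optuna_hp_space,
    backend="optuna",                               # pip install optuna
    n_trials=10,
    direction="minimize",
    compute_objective=lambda m: m["eval_loss"],     # optimise validation loss
)
print(best_run)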


The following is by Hugging Chat.

To improve the performance of your news summarization model using Flan-T5, here are the key steps and considerations based on your thought process:

  1. Data Preprocessing and Tokenization:

    • Add Task Prefix: Consider prefixing your input texts with “summarize:” or a similar instruction to help the model understand the task clearly.
    • Adjust Token Length: Evaluate whether the current max_length is too restrictive. Increasing it slightly might allow the model to capture more context, especially in longer articles (a sketch for inspecting the token-length distribution follows this list).
  2. Model Configuration:

    • Experiment with LoRA Rank: Try increasing the LoRA rank from 32 to perhaps 64 or 84. This could enhance the model’s ability to capture information, but monitor for overfitting.
    • Learning Rate and Scheduler: Adjust the learning rate (e.g., try 1e-3) and consider implementing a learning rate scheduler to optimize the training process.
  3. Training Parameters:

    • Increase Epochs: Extend the number of training epochs to 15 or more to allow the model more time to learn, ensuring thorough monitoring to prevent overfitting.
    • Data Collator Settings: Review and possibly adjust settings within DataCollatorForSeq2Seq to enhance data handling and padding efficiency (see the collator sketch after this list).
  4. Regularization Techniques:

    • Dropout Adjustment: Increase the dropout parameter in LoRA from 0.05 to 0.1 to add more regularization.
    • Gradient Clipping: Implement gradient clipping to prevent exploding gradients and stabilize training.
  5. Training Strategy:

    • Gradient Accumulation: Use gradient accumulation steps to effectively increase the batch size without exceeding hardware limits.
    • Early Stopping: Incorporate early stopping to halt training when improvement stalls, preventing overfitting (items 2, 4, and 5 are combined in the training-arguments sketch after this list).
  6. Evaluation and Metrics:

    • Expand Metrics: While focusing on ROUGE scores, consider incorporating additional evaluation metrics, such as BERTScore, for a more comprehensive analysis (a compute_metrics sketch follows this list).
  7. Hyperparameter Tuning:

    • Grid Search: Conduct a grid search, or use the built-in Trainer.hyperparameter_search sketched earlier in this reply, to identify optimal hyperparameters, including learning rate, weight decay, and the LoRA configuration.

By systematically addressing these areas, you can enhance your model’s performance, potentially leading to lower validation loss and improved ROUGE scores. Monitoring each change closely will help determine the most effective adjustments for your specific dataset and task.
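
To make a few of these concrete, here are some rough sketches that reuse the variable names from the original post (tokenizer, model, lora_config, the tokenized datasets); treat every number as a starting point, not a recommendation.

First, for the max_length point (item 1), it helps to check where the percentiles of the token-length distribution actually fall before raising the cutoff. Here raw_ds is a placeholder name for the dataset as returned by load_dataset:

import numpy as np

# column names follow gopalkalpande/bbc-news-summary (Articles / Summaries);
# raw_ds is a placeholder for the loaded dataset
article_lens = [len(tokenizer(t)["input_ids"]) for t in raw_ds["train"]["Articles"]]
summary_lens = [len(tokenizer(t)["input_ids"]) for t in raw_ds["train"]["Summaries"]]

for name, lens in [("articles", article_lens), ("summaries", summary_lens)]:
    p95, p99 = np.percentile(lens, [95, 99])
    print(f"{name}: 95th pct = {p95:.0f}, 99th pct = {p99:.0f}, max = {max(lens)}")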
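
For the data collator (item 3), there honestly isn't much to tune; the main knobs are the padding behaviour and pad_to_multiple_of, roughly:

data_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    model=model,                 # lets the collator prepare decoder_input_ids for T5
    label_pad_token_id=-100,     # padded label positions are ignored by the loss
    padding="longest",           # the default: pad per batch, not to a global max
    pad_to_multiple_of=8,        # mainly helps tensor cores when fp16/bf16 is enabled
)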
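
Items 2, 4, and 5 (scheduler, gradient clipping, gradient accumulation, early stopping) all map onto Seq2SeqTrainingArguments plus a callback. A sketch; the generation settings are assumptions, so adjust them to your summary lengths:

from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, EarlyStoppingCallback

train_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    num_train_epochs=15,
    eval_strategy="epoch",              # evaluation_strategy on older transformers versions
    save_strategy="epoch",
    learning_rate=1e-3,                 # LoRA usually tolerates a higher LR than full fine-tuning
    lr_scheduler_type="cosine",         # built-in scheduler, no need for get_scheduler by hand
    warmup_ratio=0.1,
    weight_decay=0.01,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,      # effective batch size = 8 * 4
    max_grad_norm=1.0,                  # gradient clipping
    predict_with_generate=True,
    generation_max_length=128,          # assumption: set near your summaries' 95th percentile
    generation_num_beams=4,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    logging_steps=10,
)

trainer = Seq2SeqTrainer(
    model=peft_model,
    args=train_args,
    train_dataset=tokenized_train_ds,
    eval_dataset=tokenized_val_ds,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],  # stop when eval_loss stalls
)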
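
For item 6, your compute_metrics probably already looks something like the ROUGE part below; the optional BERTScore lines show one way to add a second, semantic metric (requires the rouge_score and bert_score packages):

import numpy as np
import evaluate

rouge = evaluate.load("rouge")
bertscore = evaluate.load("bertscore")   # optional, downloads a scoring model

def compute_metrics(eval_pred):
    preds, labels = eval_pred
    if isinstance(preds, tuple):
        preds = preds[0]
    # -100 can't be decoded, so swap it back to the pad token first
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    bs = bertscore.compute(predictions=decoded_preds, references=decoded_labels, lang="en")
    result["bertscore_f1"] = float(np.mean(bs["f1"]))
    return {k: round(float(v), 4) for k, v in result.items()}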

def preprocess_news_example(example):
    instruction = (
        "### Instruction:\n"
        "Summarize the following news article.\n\n"
        "### Input:\n"
        f"{example['Articles']}\n\n"
        "### Response:"
    )
    return {
        "instruction": instruction,
        "output": example["Summaries"]
    }

This is my preprocessing function, so I'm already using an instruction prefix, and in the earlier snippets I only imported AdamW without actually using it. Could you please point out which DataCollatorForSeq2Seq settings I could modify?
