Fine-tuning Pegasus

Hi, I’ve been using the Pegasus model over the past 2 weeks and have gotten some very good results. I would like to fine-tune the model further so that its performance is more tailored to my use case.

I have some code up and running that uses Trainer. However, when looking at examples, the model does worse after training. In fact, the model output contains more and more repeated strings the longer the model is trained (i.e., more epochs). I’m wondering if my implementation is wrong, or if Trainer is not suitable for fine-tuning Pegasus (‘google/pegasus-xsum’). Am I running into catastrophic forgetting?

My code is not long, I’ve attached it below. I mostly used the tutorial(s) from:

import pandas as pd
in_df = pd.read_csv('/content/drive/My Drive/summaries_sample.csv')

# Train Test Split
train_pct = 0.6
test_pct = 0.2

in_df = in_df.sample(len(in_df), random_state=20)
train_sub = int(len(in_df) * train_pct)
test_sub = int(len(in_df) * test_pct) + train_sub

train_df = in_df[0:train_sub]
test_df = in_df[train_sub:test_sub]
val_df = in_df[test_sub:]

train_texts = list(train_df['allTextReprocess'])
test_texts = list(test_df['allTextReprocess'])
val_texts = list(val_df['allTextReprocess'])

train_decode = list(train_df['summaries'])
test_decode = list(test_df['summaries'])
val_decode = list(val_df['summaries'])

import transformers

import torch
min_length = 15
max_length = 40

# Setup model
model_name = 'google/pegasus-xsum'
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
tokenizer = transformers.PegasusTokenizer.from_pretrained(model_name)

model = transformers.PegasusForConditionalGeneration.from_pretrained(model_name).to(torch_device)
in_text = [in_df['allTextReprocess'].iloc[3]]
batch = tokenizer.prepare_seq2seq_batch(in_text, truncation=True, padding='longest').to(torch_device) 

translated = model.generate(min_length=min_length, max_length=max_length, **batch)
tgt_text0 = tokenizer.batch_decode(translated, skip_special_tokens=True)

# Tokenize
train_encodings = tokenizer(train_texts, truncation=True, padding=True)
val_encodings = tokenizer(val_texts, truncation=True, padding=True)
test_encodings = tokenizer(test_texts, truncation=True, padding=True)

train_labels = tokenizer(train_decode, truncation=True, padding=True)
val_labels = tokenizer(val_decode, truncation=True, padding=True)
test_labels = tokenizer(test_decode, truncation=True, padding=True)

# Setup dataset objects
class Summary_dataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels['input_ids'][idx])
        return item

    def __len__(self):
        # len(self.encodings) would count the dict keys, not the examples
        return len(self.encodings['input_ids'])

train_dataset = Summary_dataset(train_encodings, train_labels)
val_dataset = Summary_dataset(val_encodings, val_labels)
test_dataset = Summary_dataset(test_encodings, test_labels)
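One thing that may be contributing to the repeating strings: the labels here are padded, and Trainer’s default cross-entropy loss only skips label positions set to -100 (PyTorch’s `ignore_index`), so the model is otherwise also trained to predict pad tokens. A minimal sketch of the masking step — `mask_pad_labels` is a hypothetical helper name, and I’m assuming pad id 0, which is what the pegasus-xsum tokenizer uses:

```python
PAD_TOKEN_ID = 0  # assumed: pegasus-xsum's tokenizer uses 0 as its pad token id

def mask_pad_labels(label_ids, pad_token_id=PAD_TOKEN_ID):
    """Replace pad token ids in label sequences with -100 so the loss ignores them."""
    return [
        [tok if tok != pad_token_id else -100 for tok in seq]
        for seq in label_ids
    ]

# Toy label ids, padded to length 5:
masked = mask_pad_labels([[41, 7, 1, 0, 0], [12, 1, 0, 0, 0]])
# masked == [[41, 7, 1, -100, -100], [12, 1, -100, -100, -100]]
```

You could apply this to `train_labels['input_ids']` (and the val/test labels) before building the datasets.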

# Training
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir='./results',          # output directory
    num_train_epochs=1000,           # total number of training epochs
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=64,   # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
)

trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=val_dataset,            # evaluation dataset
)

trainer.train()


# Check results
in_text = [in_df['allTextReprocess'].iloc[3]]
batch = tokenizer.prepare_seq2seq_batch(in_text, truncation=True, padding='longest').to(torch_device) 

translated = model.generate(min_length=min_length, max_length=max_length, **batch)
tgt_text = tokenizer.batch_decode(translated, skip_special_tokens=True)

Any help would be awesome, thanks!

I also want to finetune Pegasus. Thank you for sharing your code! How similar is this to what happens in ?

Could you try this with the examples/seq2seq scripts? Also, we have recently added Trainer support for seq2seq tasks.


Thanks for the response, and sorry for my delayed reply. Using Trainer is the big difference; it abstracts a lot of the code away. It seems to be working now, though, so that’s good.

Thank you, and sorry for my delayed reply. My code seems to work now; I think there were some bad examples in my input sequences for training, and removing those helped. After fine-tuning, I was able to get rid of a lot of cases where the model would give repeating text and randomly output something about the BBC.

I’m not sure if it is necessary, but do you know if there is a way to freeze layers using Trainer?

Thanks @DeathTruck. Would you be open to sharing your working Trainer code that I could use as a starting place, or is that the code you’ve already shared?

Yeah, so the code I pasted here should work. My problem initially was that I was feeding it some bad examples, which I believe was causing the issues. My best results have come with about 1000 training samples, 1000 epochs, and lr=5e-5.

Let me know if you encounter any problems with the code.


The finetune_trainer script lets you freeze the embeddings layer and the encoder using the --freeze_embeds and --freeze_encoder arguments.

Ok, thank you! I didn’t want to completely throw out my code that I posted here, but I wound up using the freezing code in examples/seq2seq/ to freeze either the embedding or encoder layers, before passing the model to Trainer. Seems to work. Thanks!

@DeathTruck Hi, Can you please share your code and tell me how did you freeze encoder layers? I am trying to do the same but can’t figure it out.

Hi @agenius5

You can pass the --freeze_encoder flag to the script to freeze all encoder layers.
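If you’re not using the script, the freezing itself is just setting `requires_grad = False` on the relevant parameters before passing the model to Trainer. A rough sketch of the pattern, demonstrated on a toy encoder/decoder module so it runs stand-alone — `freeze_params` mirrors the helper in examples/seq2seq, and `ToySeq2Seq` is a made-up stand-in (for Pegasus you would pass `model.model.encoder`):

```python
import torch
from torch import nn

def freeze_params(module: nn.Module):
    """Disable gradients for every parameter in the given module."""
    for p in module.parameters():
        p.requires_grad = False

# Toy stand-in for a seq2seq model, just to demonstrate the pattern.
class ToySeq2Seq(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 8)
        self.decoder = nn.Linear(8, 8)

model = ToySeq2Seq()
freeze_params(model.encoder)  # for Pegasus: freeze_params(model.model.encoder)

assert all(not p.requires_grad for p in model.encoder.parameters())  # frozen
assert all(p.requires_grad for p in model.decoder.parameters())      # still trainable
```

Trainer only updates parameters with `requires_grad = True`, so the frozen encoder stays fixed during fine-tuning.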