Fine-Tuning Pegasus - Model Not Training?

I’m trying to fine-tune Pegasus using a .csv with about 4,000 samples. The encodings are web text and the labels are abstract summaries. When I train the model (50 epochs, batch size of 16), it appears as though no training is taking place: each iteration takes less than a second, and all 50 epochs finish in roughly 30 seconds. Not sure where I’m going wrong here, but I would really appreciate some help/thoughts/suggestions.

I’ve mostly been following the fine-tuning tutorial, which can be found here: https://huggingface.co/transformers/master/custom_datasets.html

Many thanks in advance!

My Code

from transformers import PegasusForConditionalGeneration, PegasusTokenizer
from transformers import Trainer, TrainingArguments
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AdamW
from torchvision import models
import pandas as pd
from torchvision import transforms, utils
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.cuda.empty_cache()


# Assign model & tokenizer
model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distill-pegasus-xsum-16-8")
tokenizer = AutoTokenizer.from_pretrained("sshleifer/distill-pegasus-xsum-16-8")


# Load data
data = pd.read_csv('C:/data.csv', sep=',', encoding='cp1252')
train_percentage = .8
test_percentage = 1-train_percentage
train_test_split_pct = int(len(data)*train_percentage)

train_summary = data.iloc[:train_test_split_pct,0].tolist()
train_webtext = data.iloc[:train_test_split_pct,1].tolist()
test_summary = data.iloc[train_test_split_pct:,0].tolist()
test_webtext = data.iloc[train_test_split_pct:,1].tolist()


# Tokenize our data
train_summary = tokenizer(train_summary, return_tensors="pt", truncation=True, padding=True)
train_webtext = tokenizer(train_webtext, return_tensors="pt",truncation=True, padding=True)
test_summary = tokenizer(test_summary, return_tensors="pt",truncation=True, padding=True)
test_webtext = tokenizer(test_webtext, return_tensors="pt",truncation=True, padding=True)


# Setup dataset objects
class Summary(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels['input_ids'][idx])  # torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)


# Get datasets
train_dataset = Summary(train_webtext, train_summary)
test_dataset = Summary(test_webtext, test_summary)


# Train model
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=50,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    warmup_steps=500,
    weight_decay=0.01,
    #evaluate_during_training=True,
    logging_dir='./logs',
)

trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=test_dataset
)

trainer.train()

Output

Epoch:   0%|          | 0/50 [00:00<?, ?it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:01<00:00,  1.12s/it]
Epoch:   2%|▏         | 1/50 [00:01<00:55,  1.12s/it]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.30it/s]
Epoch:   4%|▍         | 2/50 [00:01<00:48,  1.02s/it]
...
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.63it/s]
Epoch: 100%|██████████| 50/50 [00:32<00:00,  1.53it/s]
TrainOutput(global_step=50, training_loss=8.380059204101563)

I don’t see any error in the training code. I think you should manually verify that the number of examples in the DataFrame matches the length of the dataset.
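
For example, a quick check along these lines (just a sketch, reusing the variable names from your script) would show whether the Dataset reports the number of examples you expect:

# Sanity check (variable names from the script above): if len(train_dataset)
# comes out much smaller than the number of rows, the Dataset is the place to look.
print("rows in dataframe:     ", len(data))
print("encoded train examples:", len(train_webtext["input_ids"]))
print("len(train_dataset):    ", len(train_dataset))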

The data looks good; I don’t see any issues. The tensors all have the same length. It’s beyond me why I can’t get this model to train.
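
For reference, this is roughly how I checked the tokenized tensors (a sketch using the variables from my script above):

# Both encodings are padded to a common length, so the shapes line up.
print(train_webtext["input_ids"].shape)   # (num train examples, padded source length)
print(train_summary["input_ids"].shape)   # (num train examples, padded summary length)
print(test_webtext["input_ids"].shape)
print(test_summary["input_ids"].shape)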

Is it because you didn’t set the learning rate? If you can get it to work, please let us know; I’d like to use this script myself.
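
If you want to try a specific learning rate, it can be passed straight through TrainingArguments; the value below is just an illustration (the Trainer default is 5e-5):

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=50,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    learning_rate=1e-4,   # illustrative value, not tuned
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
)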

This works in Colab.
I only ran it for 3 epochs on a small portion of a dataset, but it’s training and the metrics improved.
The main changes are:

  • passing in the decoder_input_ids parameter (no right shifting in this example)
  • handling the -100 pad_token_id value for the labels
  • using Adafactor instead of AdamW
  • using a label_smoothing_factor
  • setting the learning_rate to one of the values from the paper (appendix C)

I hope it helps.

! pip install transformers
! pip install datasets
! pip install sentencepiece
! pip install rouge_score

import torch
import datasets

from transformers import (
    AutoModelForSeq2SeqLM, AutoTokenizer,
    Seq2SeqTrainingArguments, Seq2SeqTrainer
  )

model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distill-pegasus-xsum-16-8")
tokenizer = AutoTokenizer.from_pretrained("sshleifer/distill-pegasus-xsum-16-8")

# Use wiki_lingua for demonstration
language = "english"
data = datasets.load_dataset("wiki_lingua", name=language, split="train[:1%]")
source = [x["document"] for x in data["article"] if len(x["document"]) > 0]
target = [x["summary"] for x in data["article"] if len(x["document"]) > 0]
source = [item for sublist in source for item in sublist]
target = [item for sublist in target for item in sublist]
N_DEMO = 100  # Try on a small portion of the data
source = source[:N_DEMO]
target = target[:N_DEMO]

train_percentage = .8
test_percentage = 1-train_percentage
train_test_split_pct = int(len(source)*train_percentage)

train_summary = target[:train_test_split_pct]
train_webtext = source[:train_test_split_pct]
test_summary = target[train_test_split_pct:]
test_webtext = source[train_test_split_pct:]

encoder_max_length=512
decoder_max_length=64

# Tokenize our data
train_summary = tokenizer(train_summary, padding="max_length", 
                          truncation=True, max_length=decoder_max_length)
train_webtext = tokenizer(train_webtext, padding="max_length", 
                          truncation=True, max_length=encoder_max_length)
test_summary = tokenizer(test_summary, padding="max_length", 
                         truncation=True, max_length=decoder_max_length)
test_webtext = tokenizer(test_webtext, padding="max_length", 
                         truncation=True, max_length=encoder_max_length)

# Setup dataset objects
class Summary(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        labels = self.labels['input_ids'][idx]
        labels100 = [-100 if token == tokenizer.pad_token_id 
                     else token for token in labels]
        item['labels'] = torch.tensor(labels100)
        item['decoder_input_ids'] = torch.tensor(labels)            
        return item

    def __len__(self):
        # Return the number of encoded examples, not the number of keys
        # in the tokenizer output.
        return len(self.labels['input_ids'])


# Get datasets
train_dataset = Summary(train_webtext, train_summary)
test_dataset = Summary(test_webtext, test_summary)

rouge = datasets.load_metric("rouge")

def compute_metrics(pred):
    """
    Borrowed from https://github.com/patrickvonplaten/notebooks/blob/master/RoBERTaShared_for_BBC_XSum.ipynb
    """
    labels_ids = pred.label_ids
    pred_ids = pred.predictions

    # all unnecessary tokens are removed
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    labels_ids[labels_ids == -100] = tokenizer.pad_token_id
    label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

    rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=["rouge2"])["rouge2"].mid

    return {
        "rouge2_precision": round(rouge_output.precision, 4),
        "rouge2_recall": round(rouge_output.recall, 4),
        "rouge2_fmeasure": round(rouge_output.fmeasure, 4),
    }

batch_size = 4  # demo

training_args = Seq2SeqTrainingArguments( 
    output_dir='./results',
    adafactor=True,     
    num_train_epochs=3, # 50, # demo
    per_device_train_batch_size=batch_size, 
    per_device_eval_batch_size=batch_size,
    predict_with_generate=True,
    do_train=True,
    do_eval=True,
    learning_rate=5e-4,
    label_smoothing_factor=0.1,
    overwrite_output_dir=True, 
    logging_dir='./logs'
)

trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=test_dataset
)

trainer.evaluate()

trainer.train()

trainer.evaluate()
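
And if it’s useful, here is roughly how you can generate summaries with the fine-tuned model afterwards (a sketch; the beam search settings are just reasonable defaults, not tuned):

# Generate a summary for one held-out article (raw text, before tokenization).
model.eval()
sample_text = source[train_test_split_pct]
inputs = tokenizer(sample_text, truncation=True, max_length=encoder_max_length,
                   return_tensors="pt").to(model.device)
summary_ids = model.generate(**inputs, num_beams=4, max_length=decoder_max_length)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))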