Fine-Tuning Pegasus - Model Not Training?

I’m trying to fine-tune Pegasus using a .csv file with about 4,000 samples. The encodings are web text and the labels are abstract summaries. When I train the model (50 epochs, batch size of 16), it appears that no training is taking place: each iteration takes about a second or less, and all 50 epochs finish in roughly 30 seconds. I’m not sure where I’m going wrong, but I would really appreciate any help, thoughts, or suggestions.

I’ve been following the fine-tuning tutorial for the most part, which can be found here: https://huggingface.co/transformers/master/custom_datasets.html

Many thanks in advance!

My Code

from transformers import PegasusForConditionalGeneration, PegasusTokenizer
from transformers import Trainer, TrainingArguments
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AdamW
from torchvision import models
import pandas as pd
from torchvision import transforms, utils
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.cuda.empty_cache()


# Assign model & tokenizer
model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distill-pegasus-xsum-16-8")
tokenizer = AutoTokenizer.from_pretrained("sshleifer/distill-pegasus-xsum-16-8")


# Load data
data = pd.read_csv('C:/data.csv', sep=',', encoding='cp1252')
train_percentage = .8
test_percentage = 1-train_percentage
train_test_split_pct = int(len(data)*train_percentage)

train_summary = data.iloc[:train_test_split_pct,0].tolist()
train_webtext = data.iloc[:train_test_split_pct,1].tolist()
test_summary = data.iloc[train_test_split_pct:,0].tolist()
test_webtext = data.iloc[train_test_split_pct:,1].tolist()


# Tokenize our data
train_summary = tokenizer(train_summary, return_tensors="pt", truncation=True, padding=True)
train_webtext = tokenizer(train_webtext, return_tensors="pt",truncation=True, padding=True)
test_summary = tokenizer(test_summary, return_tensors="pt",truncation=True, padding=True)
test_webtext = tokenizer(test_webtext, return_tensors="pt",truncation=True, padding=True)


# Setup dataset objects
class Summary(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels['input_ids'][idx])  # torch.tensor(self.labels[idx])
        return item

    def __len__(self):
        return len(self.labels)


# Get datasets
train_dataset = Summary(train_webtext, train_summary)
test_dataset = Summary(test_webtext, test_summary)


# Train model
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=50,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    warmup_steps=500,
    weight_decay=0.01,
    #evaluate_during_training=True,
    logging_dir='./logs',
)

trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=test_dataset
)

trainer.train()

Output

Epoch:   0%|          | 0/50 [00:00<?, ?it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:01<00:00,  1.12s/it]
Epoch:   2%|▏         | 1/50 [00:01<00:55,  1.12s/it]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.30it/s]
Epoch:   4%|▍         | 2/50 [00:01<00:48,  1.02s/it]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.48it/s]
Epoch:   6%|▌         | 3/50 [00:02<00:43,  1.09it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.59it/s]
Epoch:   8%|▊         | 4/50 [00:03<00:38,  1.20it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.62it/s]
Epoch:  10%|█         | 5/50 [00:03<00:34,  1.30it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.66it/s]
Epoch:  12%|█▏        | 6/50 [00:04<00:31,  1.39it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.55it/s]
Epoch:  14%|█▍        | 7/50 [00:05<00:30,  1.43it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.60it/s]
Epoch:  16%|█▌        | 8/50 [00:05<00:28,  1.48it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.61it/s]
Epoch:  18%|█▊        | 9/50 [00:06<00:27,  1.51it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.66it/s]
Epoch:  20%|██        | 10/50 [00:06<00:25,  1.55it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.55it/s]
Epoch:  22%|██▏       | 11/50 [00:07<00:25,  1.55it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.61it/s]
Epoch:  24%|██▍       | 12/50 [00:08<00:24,  1.56it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.61it/s]
Epoch:  26%|██▌       | 13/50 [00:08<00:23,  1.57it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.60it/s]
Epoch:  28%|██▊       | 14/50 [00:09<00:22,  1.58it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.63it/s]
Epoch:  30%|███       | 15/50 [00:10<00:21,  1.59it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.55it/s]
Epoch:  32%|███▏      | 16/50 [00:10<00:21,  1.58it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.63it/s]
Epoch:  34%|███▍      | 17/50 [00:11<00:20,  1.59it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.61it/s]
Epoch:  36%|███▌      | 18/50 [00:11<00:20,  1.59it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.50it/s]
Epoch:  38%|███▊      | 19/50 [00:12<00:19,  1.56it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.57it/s]
Epoch:  40%|████      | 20/50 [00:13<00:19,  1.56it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.54it/s]
Epoch:  42%|████▏     | 21/50 [00:13<00:18,  1.55it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.55it/s]
Epoch:  44%|████▍     | 22/50 [00:14<00:18,  1.55it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.58it/s]
Epoch:  46%|████▌     | 23/50 [00:15<00:17,  1.56it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.57it/s]
Epoch:  48%|████▊     | 24/50 [00:15<00:16,  1.56it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.50it/s]
Epoch:  50%|█████     | 25/50 [00:16<00:16,  1.54it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.36it/s]
Epoch:  52%|█████▏    | 26/50 [00:17<00:16,  1.48it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.58it/s]
Epoch:  54%|█████▍    | 27/50 [00:17<00:15,  1.51it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.63it/s]
Epoch:  56%|█████▌    | 28/50 [00:18<00:14,  1.54it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.54it/s]
Epoch:  58%|█████▊    | 29/50 [00:19<00:13,  1.54it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.62it/s]
Epoch:  60%|██████    | 30/50 [00:19<00:12,  1.56it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.55it/s]
Epoch:  62%|██████▏   | 31/50 [00:20<00:12,  1.55it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.54it/s]
Epoch:  64%|██████▍   | 32/50 [00:21<00:11,  1.54it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.44it/s]
Epoch:  66%|██████▌   | 33/50 [00:21<00:11,  1.51it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.56it/s]
Epoch:  68%|██████▊   | 34/50 [00:22<00:10,  1.52it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.55it/s]
Epoch:  70%|███████   | 35/50 [00:23<00:09,  1.52it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.48it/s]
Epoch:  72%|███████▏  | 36/50 [00:23<00:09,  1.51it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.69it/s]
Epoch:  74%|███████▍  | 37/50 [00:24<00:08,  1.56it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.56it/s]
Epoch:  76%|███████▌  | 38/50 [00:25<00:07,  1.55it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.60it/s]
Epoch:  78%|███████▊  | 39/50 [00:25<00:07,  1.57it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.44it/s]
Epoch:  80%|████████  | 40/50 [00:26<00:06,  1.52it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.57it/s]
Epoch:  82%|████████▏ | 41/50 [00:26<00:05,  1.53it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.50it/s]
Epoch:  84%|████████▍ | 42/50 [00:27<00:05,  1.52it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.49it/s]
Epoch:  86%|████████▌ | 43/50 [00:28<00:04,  1.51it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.65it/s]
Epoch:  88%|████████▊ | 44/50 [00:28<00:03,  1.55it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.63it/s]
Epoch:  90%|█████████ | 45/50 [00:29<00:03,  1.57it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.47it/s]
Epoch:  92%|█████████▏| 46/50 [00:30<00:02,  1.54it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.64it/s]
Epoch:  94%|█████████▍| 47/50 [00:30<00:01,  1.56it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.51it/s]
Epoch:  96%|█████████▌| 48/50 [00:31<00:01,  1.54it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.58it/s]
Epoch:  98%|█████████▊| 49/50 [00:32<00:00,  1.55it/s]
Iteration:   0%|          | 0/1 [00:00<?, ?it/s]
Iteration: 100%|██████████| 1/1 [00:00<00:00,  1.63it/s]
Epoch: 100%|██████████| 50/50 [00:32<00:00,  1.53it/s]
TrainOutput(global_step=50, training_loss=8.380059204101563)

I don’t see any error in the training code. I think you should manually verify the number of examples in the dataframe (len(data)) and the length of the dataset (len(train_dataset)).
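
As a rough sanity check along those lines, here is a minimal sketch of the step count one would expect from the figures in the question, next to what the posted log reports. The row count is approximate ("about 4,000"), and the sketch assumes a single device and no gradient accumulation; the variable names are only illustrative.

```python
import math

# Figures from the question (the row count is approximate).
n_rows = 4000
train_fraction = 0.8
batch_size = 16
epochs = 50

# Same split arithmetic as in the posted script.
n_train = int(n_rows * train_fraction)

# Assumption: one device, no gradient accumulation, last partial batch kept.
expected_steps_per_epoch = math.ceil(n_train / batch_size)
expected_total_steps = expected_steps_per_epoch * epochs

# Values read off the posted log: "0/1" iterations per epoch, global_step=50.
observed_steps_per_epoch = 1
observed_total_steps = 50

# A single step per epoch means the dataset reported at most this many examples.
implied_max_dataset_len = observed_steps_per_epoch * batch_size

print(f"expected steps per epoch: {expected_steps_per_epoch}")
print(f"expected total steps:     {expected_total_steps}")
print(f"observed total steps:     {observed_total_steps}")
print(f"dataset length implied by the log: <= {implied_max_dataset_len}")
```

If len(train_dataset) comes back far smaller than the number of training rows, the problem is likely in how the dataset reports its length rather than in the Trainer setup.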

The data looks good; I don’t see any issues. The tensors all have the same length. It is beyond me why I cannot get this model to train.
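
One thing that may be worth checking, offered as a guess rather than a confirmed diagnosis: `__len__` in the `Summary` class returns `len(self.labels)`, but `self.labels` is the tokenizer output, not a list of examples. As far as I know, that output is a dict-like object, so `len()` counts its keys (typically `input_ids` and `attention_mask`) rather than its samples. That would give a dataset of length 2 and a single batch per epoch, which matches the `0/1` iterations and `global_step=50` in the log. The sketch below uses a plain `UserDict` as a hypothetical stand-in for the tokenizer output so that it runs without `transformers` or `torch`; the class names `SummaryAsPosted` and `SummaryFixed` are made up for illustration.

```python
import math
from collections import UserDict


class FakeBatchEncoding(UserDict):
    """Hypothetical stand-in for the dict-like object a tokenizer returns."""


def make_encoding(n_samples, seq_len=4):
    # One entry per field; each field holds one row per sample.
    return FakeBatchEncoding({
        "input_ids": [[0] * seq_len for _ in range(n_samples)],
        "attention_mask": [[1] * seq_len for _ in range(n_samples)],
    })


class SummaryAsPosted:
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        # Counts the keys of the dict-like object, not the samples.
        return len(self.labels)


class SummaryFixed(SummaryAsPosted):
    def __len__(self):
        # Counts the rows under one field, i.e. the number of samples.
        return len(self.labels["input_ids"])


n_train, batch_size = 3200, 16
webtext, summary = make_encoding(n_train), make_encoding(n_train)

posted_len = len(SummaryAsPosted(webtext, summary))
fixed_len = len(SummaryFixed(webtext, summary))

print("as posted:", posted_len, "examples,", math.ceil(posted_len / batch_size), "step(s) per epoch")
print("fixed:    ", fixed_len, "examples,", math.ceil(fixed_len / batch_size), "step(s) per epoch")
```

If this is indeed the cause, changing the return value in the original class to `len(self.labels['input_ids'])` should make the Trainer see every training example.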