I’m trying to fine-tune Pegasus using a .csv
with about 4,000 samples. The encodings are web text and the labels are abstract summaries. When I go to train the model (50 epochs, batch size of 16), it appears as though no training is taking place. Each iteration takes < 1 second and nearly ~30 seconds to iterate through 50 epochs… Not sure where I’m going wrong here, but would really appreciate some help/thoughts/suggestions.
I’ve been following the fine-tuning tutorial for the most part, which can be found here: https://huggingface.co/transformers/master/custom_datasets.html
Many thanks in advance!
My Code
from transformers import PegasusForConditionalGeneration, PegasusTokenizer
from transformers import Trainer, TrainingArguments
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AdamW
from torchvision import models
import pandas as pd
from torchvision import transforms, utils
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.cuda.empty_cache()
# Assign model & tokenizer
model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distill-pegasus-xsum-16-8")
tokenizer = AutoTokenizer.from_pretrained("sshleifer/distill-pegasus-xsum-16-8")
# Load data
data = pd.read_csv('C:/data.csv', sep=',', encoding='cp1252')
train_percentage = .8
test_percentage = 1-train_percentage
train_test_split_pct = int(len(data)*train_percentage)
train_summary = data.iloc[:train_test_split_pct,0].tolist()
train_webtext = data.iloc[:train_test_split_pct,1].tolist()
test_summary = data.iloc[train_test_split_pct:,0].tolist()
test_webtext = data.iloc[train_test_split_pct:,1].tolist()
# Tokenize our data
train_summary = tokenizer(train_summary, return_tensors="pt", truncation=True, padding=True)
train_webtext = tokenizer(train_webtext, return_tensors="pt",truncation=True, padding=True)
test_summary = tokenizer(test_summary, return_tensors="pt",truncation=True, padding=True)
test_webtext = tokenizer(test_webtext, return_tensors="pt",truncation=True, padding=True)
# Setup dataset objects
class Summary(torch.utils.data.Dataset):
def __init__(self, encodings, labels):
self.encodings = encodings
self.labels = labels
def __getitem__(self, idx):
item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
item['labels'] = torch.tensor(self.labels['input_ids'][idx]) # torch.tensor(self.labels[idx])
return item
def __len__(self):
return len(self.labels)
# Get datasets
train_dataset = Summary(train_webtext, train_summary)
test_dataset = Summary(test_webtext, test_summary)
# Train model
def compute_metrics(pred):
labels = pred.label_ids
preds = pred.predictions.argmax(-1)
precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
acc = accuracy_score(labels, preds)
return {
'accuracy': acc,
'f1': f1,
'precision': precision,
'recall': recall
}
training_args = TrainingArguments(
output_dir='./results',
num_train_epochs=50,
per_device_train_batch_size=16,
per_device_eval_batch_size=16,
warmup_steps=500,
weight_decay=0.01,
#evaluate_during_training=True,
logging_dir='./logs',
)
trainer = Trainer(
model=model,
args=training_args,
compute_metrics=compute_metrics,
train_dataset=train_dataset,
eval_dataset=test_dataset
)
trainer.train()
Output
Epoch: 0%| | 0/50 [00:00<?, ?it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:01<00:00, 1.12s/it]
Epoch: 2%|▏ | 1/50 [00:01<00:55, 1.12s/it]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.30it/s]
Epoch: 4%|▍ | 2/50 [00:01<00:48, 1.02s/it]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.48it/s]
Epoch: 6%|▌ | 3/50 [00:02<00:43, 1.09it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.59it/s]
Epoch: 8%|▊ | 4/50 [00:03<00:38, 1.20it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.62it/s]
Epoch: 10%|█ | 5/50 [00:03<00:34, 1.30it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.66it/s]
Epoch: 12%|█▏ | 6/50 [00:04<00:31, 1.39it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.55it/s]
Epoch: 14%|█▍ | 7/50 [00:05<00:30, 1.43it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.60it/s]
Epoch: 16%|█▌ | 8/50 [00:05<00:28, 1.48it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.61it/s]
Epoch: 18%|█▊ | 9/50 [00:06<00:27, 1.51it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.66it/s]
Epoch: 20%|██ | 10/50 [00:06<00:25, 1.55it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.55it/s]
Epoch: 22%|██▏ | 11/50 [00:07<00:25, 1.55it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.61it/s]
Epoch: 24%|██▍ | 12/50 [00:08<00:24, 1.56it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.61it/s]
Epoch: 26%|██▌ | 13/50 [00:08<00:23, 1.57it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.60it/s]
Epoch: 28%|██▊ | 14/50 [00:09<00:22, 1.58it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.63it/s]
Epoch: 30%|███ | 15/50 [00:10<00:21, 1.59it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.55it/s]
Epoch: 32%|███▏ | 16/50 [00:10<00:21, 1.58it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.63it/s]
Epoch: 34%|███▍ | 17/50 [00:11<00:20, 1.59it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.61it/s]
Epoch: 36%|███▌ | 18/50 [00:11<00:20, 1.59it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.50it/s]
Epoch: 38%|███▊ | 19/50 [00:12<00:19, 1.56it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.57it/s]
Epoch: 40%|████ | 20/50 [00:13<00:19, 1.56it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.54it/s]
Epoch: 42%|████▏ | 21/50 [00:13<00:18, 1.55it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.55it/s]
Epoch: 44%|████▍ | 22/50 [00:14<00:18, 1.55it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.58it/s]
Epoch: 46%|████▌ | 23/50 [00:15<00:17, 1.56it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.57it/s]
Epoch: 48%|████▊ | 24/50 [00:15<00:16, 1.56it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.50it/s]
Epoch: 50%|█████ | 25/50 [00:16<00:16, 1.54it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.36it/s]
Epoch: 52%|█████▏ | 26/50 [00:17<00:16, 1.48it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.58it/s]
Epoch: 54%|█████▍ | 27/50 [00:17<00:15, 1.51it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.63it/s]
Epoch: 56%|█████▌ | 28/50 [00:18<00:14, 1.54it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.54it/s]
Epoch: 58%|█████▊ | 29/50 [00:19<00:13, 1.54it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.62it/s]
Epoch: 60%|██████ | 30/50 [00:19<00:12, 1.56it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.55it/s]
Epoch: 62%|██████▏ | 31/50 [00:20<00:12, 1.55it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.54it/s]
Epoch: 64%|██████▍ | 32/50 [00:21<00:11, 1.54it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.44it/s]
Epoch: 66%|██████▌ | 33/50 [00:21<00:11, 1.51it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.56it/s]
Epoch: 68%|██████▊ | 34/50 [00:22<00:10, 1.52it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.55it/s]
Epoch: 70%|███████ | 35/50 [00:23<00:09, 1.52it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.48it/s]
Epoch: 72%|███████▏ | 36/50 [00:23<00:09, 1.51it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.69it/s]
Epoch: 74%|███████▍ | 37/50 [00:24<00:08, 1.56it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.56it/s]
Epoch: 76%|███████▌ | 38/50 [00:25<00:07, 1.55it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.60it/s]
Epoch: 78%|███████▊ | 39/50 [00:25<00:07, 1.57it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.44it/s]
Epoch: 80%|████████ | 40/50 [00:26<00:06, 1.52it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.57it/s]
Epoch: 82%|████████▏ | 41/50 [00:26<00:05, 1.53it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.50it/s]
Epoch: 84%|████████▍ | 42/50 [00:27<00:05, 1.52it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.49it/s]
Epoch: 86%|████████▌ | 43/50 [00:28<00:04, 1.51it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.65it/s]
Epoch: 88%|████████▊ | 44/50 [00:28<00:03, 1.55it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.63it/s]
Epoch: 90%|█████████ | 45/50 [00:29<00:03, 1.57it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.47it/s]
Epoch: 92%|█████████▏| 46/50 [00:30<00:02, 1.54it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.64it/s]
Epoch: 94%|█████████▍| 47/50 [00:30<00:01, 1.56it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.51it/s]
Epoch: 96%|█████████▌| 48/50 [00:31<00:01, 1.54it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.58it/s]
Epoch: 98%|█████████▊| 49/50 [00:32<00:00, 1.55it/s]
Iteration: 0%| | 0/1 [00:00<?, ?it/s]A
Iteration: 100%|██████████| 1/1 [00:00<00:00, 1.63it/s]
Epoch: 100%|██████████| 50/50 [00:32<00:00, 1.53it/s]
TrainOutput(global_step=50, training_loss=8.380059204101563)