Hi everyone, I need some help/guidance with my current project: image captioning using ViT + GPT-2 (base) on Flickr8k. I'm stuck because I see others getting great validation loss and ROUGE scores with the same architecture, while in my case neither the loss nor the scores come anywhere close, so I'm hoping someone can point out what might be holding back my model's learning. My full setup is below.
Model
import torch
from transformers import (
    VisionEncoderDecoderModel,
    AutoImageProcessor,
    AutoTokenizer,
)
# === VisionEncoderDecoder with ViT-Base + GPT-2 === #
vit_model_name = "google/vit-base-patch16-224"  # ViT-Base, 224x224 input, patch size 16
decoder_model_name = "gpt2"
# Combine into a VisionEncoderDecoderModel
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(vit_model_name, decoder_model_name)
# Image processor and tokenizer
processor = AutoImageProcessor.from_pretrained(vit_model_name, use_fast=True)
tokenizer = AutoTokenizer.from_pretrained(decoder_model_name)
# Add special tokens if needed. Note: GPT-2 already has bos/eos
# ('<|endoftext|>'), so in practice only the pad token is added here.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
if tokenizer.bos_token is None or tokenizer.eos_token is None:
    tokenizer.add_special_tokens({
        'bos_token': '[CLS]',
        'eos_token': '[SEP]'
    })
# Resize decoder embeddings only
model.decoder.resize_token_embeddings(len(tokenizer))
# Set generation config
model.config.decoder_start_token_id = tokenizer.bos_token_id or tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.eos_token_id or tokenizer.sep_token_id
model.config.vocab_size = model.config.decoder.vocab_size
# Beam search config
model.config.max_length = 128
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4
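For completeness, a quick smoke test along these lines can confirm the token ids and the encoder→decoder hookup before training; the random tensor is just a stand-in for a real preprocessed image, so the generated text is meaningless:
# Smoke test: one dummy "image" through generate() to confirm the wiring.
# 224x224 matches the ViT-Base processor output; the caption itself will be
# garbage because the input is random noise.
dummy = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    out_ids = model.generate(pixel_values=dummy, max_new_tokens=10)
print(tokenizer.batch_decode(out_ids, skip_special_tokens=True))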
Dataset Class
from torch.utils.data import Dataset
import torch
class MultiCaptionImageDataset(Dataset):
    def __init__(self, hf_dataset, processor, tokenizer, max_length):
        self.dataset = hf_dataset
        self.processor = processor
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        # Each image yields 5 rows (one per caption)
        return len(self.dataset) * 5

    def __getitem__(self, idx):
        # Map the flat index back to an (image, caption) pair
        image_idx = idx // 5   # integer division -> image index
        caption_idx = idx % 5  # modulo -> which caption (0 to 4)
        item = self.dataset[image_idx]
        image = item["image"]
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Process image
        pixel_values = self.processor(images=image, return_tensors="pt")["pixel_values"].squeeze(0)

        # Dynamically select caption_0 to caption_4
        caption = item[f"caption_{caption_idx}"]

        # Tokenize caption. GPT-2's tokenizer does NOT append an EOS token, so
        # add it explicitly; otherwise the model never learns to stop and
        # generations run to max length, which hurts BLEU/ROUGE badly.
        tokenized = self.tokenizer(
            caption + self.tokenizer.eos_token,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        input_ids = tokenized["input_ids"].squeeze(0)
        attention_mask = tokenized["attention_mask"].squeeze(0)

        # Mask the loss on padding positions
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "pixel_values": pixel_values,
            "labels": labels,
            "attention_mask": attention_mask,
        }
from torch.utils.data import DataLoader
train_dataset = MultiCaptionImageDataset(ds['train'], processor, tokenizer, max_length)
val_dataset = MultiCaptionImageDataset(ds['validation'], processor, tokenizer, max_length)
test_dataset = MultiCaptionImageDataset(ds['test'], processor, tokenizer, max_length)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=2)
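Before training, I also sanity-check a single example from the dataset by decoding the label ids back to text (the expected shape in the comment assumes the default 224x224 ViT processor):
# Sanity-check one example: image tensor shape plus the exact caption string
# the model is trained on (-100 marks positions masked from the loss).
sample = train_dataset[0]
print(sample["pixel_values"].shape)  # expect torch.Size([3, 224, 224])
visible_ids = [t for t in sample["labels"].tolist() if t != -100]
print(tokenizer.decode(visible_ids))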
Training Loop
import torch
import time
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
import nltk
nltk.download("punkt")
# Metric setup
smoothie = SmoothingFunction().method4
rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rouge3', 'rouge4', 'rougeL'], use_stemmer=True)
def compute_metrics(reference, hypothesis):
    ref = [reference.split()]
    hyp = hypothesis.split()
    scores = {}
    # Cumulative BLEU-1..4 (uniform weights over the n-gram orders used)
    for i in range(1, 5):
        weights = tuple([1 / i] * i + [0] * (4 - i))
        scores[f"BLEU-{i}"] = sentence_bleu(ref, hyp, weights=weights, smoothing_function=smoothie)
    rouge_scores = rouge.score(reference, hypothesis)
    for key in rouge_scores:
        scores[key.upper()] = rouge_scores[key].fmeasure
    return scores
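# Quick self-check of the metric helper with toy strings (the values are
# meaningless; this only verifies the plumbing returns all BLEU/ROUGE keys):
_demo = compute_metrics("a dog runs across the grass", "a dog is running on grass")
print({k: round(v, 3) for k, v in _demo.items()})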
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}, GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
model.to(device)
# Optimizer & Scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=1, threshold=0.03,
    threshold_mode='rel', min_lr=1e-6,  # verbose is deprecated; LR changes are printed manually below
)
# Mixed precision
scaler = GradScaler()
num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    epoch_start_time = time.time()
    print(f"\n🧪 Epoch {epoch + 1}/{num_epochs}")

    progress_bar = tqdm(train_loader, desc="Training", leave=False)
    for step, batch in enumerate(progress_bar):
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        with autocast():
            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss

        scaler.scale(loss).backward()
        # Unscale before clipping; otherwise the norm is computed on
        # AMP-scaled gradients and max_norm=1.0 does nothing useful
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        avg_loss = running_loss / (step + 1)
        progress_bar.set_postfix(loss=f"{avg_loss:.4f}")

    train_loss = running_loss / len(train_loader)
    print(f"✅ Train Loss: {train_loss:.4f} | Time: {(time.time() - epoch_start_time) / 60:.2f} min")
    # Validation
    model.eval()
    val_loss = 0.0
    all_metrics = []

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation", leave=False):
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)

            with autocast():
                outputs = model(pixel_values=pixel_values, labels=labels)
            val_loss += outputs.loss.item()

            # Generate captions
            gen_ids = model.generate(pixel_values=pixel_values, max_new_tokens=50)
            preds = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)

            # Clean label sequences by removing the -100 loss-mask values
            cleaned_labels = []
            for label_seq in labels:
                cleaned_seq = [token_id.item() for token_id in label_seq if token_id.item() != -100]
                cleaned_labels.append(cleaned_seq)
            refs = tokenizer.batch_decode(cleaned_labels, skip_special_tokens=True)

            # Compute metrics (each prediction vs. its single reference caption)
            for ref, pred in zip(refs, preds):
                metrics = compute_metrics(ref.strip(), pred.strip())
                all_metrics.append(metrics)

    avg_val_loss = val_loss / len(val_loader)
    print(f"🔍 Validation Loss: {avg_val_loss:.4f}")

    # Aggregate metrics
    avg_metrics = {k: sum(m[k] for m in all_metrics) / len(all_metrics) for k in all_metrics[0]}
    for name, value in avg_metrics.items():
        print(f"{name}: {value:.4f}")
    # LR scheduler (steps on validation loss)
    prev_lr = optimizer.param_groups[0]['lr']
    scheduler.step(avg_val_loss)
    new_lr = optimizer.param_groups[0]['lr']
    if new_lr < prev_lr:
        print(f"📉 LR reduced from {prev_lr:.6f} to {new_lr:.6f}")
    else:
        print(f"⚠️ LR not reduced. Still at {new_lr:.6f}")
    print(f"🧠 Best val loss so far (tracked by scheduler): {scheduler.best:.6f}")