Image Captioning with ViT and GPT 2 Base

hi everyone, I need some help/guidance in my current project about image captioning using ViT + GPT2 Base on Flickr8k …so basically I’m stuck at the point where I see others having great val loss and rouge scores using the same architecture but in my case the the loss is not so great neither the scores so I’m hoping anyone could point me the difficulties that are affecting my model learning…

Model

import torch
import torch.nn as nn
from transformers import (
    ViTModel, ViTConfig,
    GPT2Config, AutoModelForCausalLM,
    VisionEncoderDecoderModel,
    AutoImageProcessor, AutoTokenizer
)

# === VisionEncoderDecoder with ViT256 + GPT-2 === #
vit_model_name = "google/vit-base-patch16-224"  # ViT with 256x256 support
decoder_model_name = "gpt2"

# Combine into VisionEncoderDecoderModel
model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained('google/vit-base-patch16-224', 'gpt2')

# Image processor and tokenizer
processor = AutoImageProcessor.from_pretrained(vit_model_name,use_fast=True)
tokenizer = AutoTokenizer.from_pretrained(decoder_model_name)

# Add special tokens if needed
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': '[PAD]'})
if tokenizer.bos_token is None or tokenizer.eos_token is None:
    tokenizer.add_special_tokens({
        'bos_token': '[CLS]',
        'eos_token': '[SEP]'
    })

# Resize decoder embeddings only
model.decoder.resize_token_embeddings(len(tokenizer))

# Set generation config
model.config.decoder_start_token_id = tokenizer.bos_token_id or tokenizer.cls_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.config.eos_token_id = tokenizer.eos_token_id or tokenizer.sep_token_id
model.config.vocab_size = model.config.decoder.vocab_size

# Beam search config
model.config.max_length = 128
model.config.early_stopping = True
model.config.no_repeat_ngram_size = 3
model.config.length_penalty = 2.0
model.config.num_beams = 4


Dataset Class

from torch.utils.data import Dataset
import torch

class MultiCaptionImageDataset(Dataset):
    def __init__(self, hf_dataset, processor, tokenizer, max_length):
        self.dataset = hf_dataset
        self.processor = processor
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        # Each image will generate 5 rows (one for each caption)
        return len(self.dataset) * 5

    def __getitem__(self, idx):
        # For each row, determine which image-caption pair we are fetching
        image_idx = idx // 5  # Integer division to get the image index
        caption_idx = idx % 5  # Modulo to get which caption (0 to 4) for this row

        item = self.dataset[image_idx]

        image = item["image"]
        if image.mode != "RGB":
            image = image.convert("RGB")

        # Process image
        pixel_values = self.processor(images=image, return_tensors="pt")["pixel_values"].squeeze(0)

        # Access the selected caption
        caption_column = f"caption_{caption_idx}"  # Dynamically select caption_0 to caption_4
        caption = item[caption_column]

        # Tokenize caption
        tokenized = self.tokenizer(
            caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )

        input_ids = tokenized["input_ids"].squeeze(0)
        attention_mask = tokenized["attention_mask"].squeeze(0)

        # Mask loss on padding
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "pixel_values": pixel_values,
            "labels": labels,
            "attention_mask": attention_mask
        }
from torch.utils.data import DataLoader


train_dataset = MultiCaptionImageDataset(ds['train'], processor, tokenizer, max_length)
val_dataset = MultiCaptionImageDataset(ds['validation'], processor, tokenizer, max_length)
test_dataset = MultiCaptionImageDataset(ds['test'], processor, tokenizer, max_length)


train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers = 2)
val_loader = DataLoader(val_dataset, batch_size=32, num_workers = 2, shuffle = False)
test_loader = DataLoader(test_dataset, batch_size=32, num_workers = 2, shuffle = False)

Training Loop

import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import time
from tqdm import tqdm
from torch.cuda.amp import autocast, GradScaler

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
import nltk
nltk.download("punkt")

# Metric setup
smoothie = SmoothingFunction().method4
rouge = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rouge3', 'rouge4', 'rougeL'], use_stemmer=True)

def compute_metrics(reference, hypothesis):
    ref = [reference.split()]
    hyp = hypothesis.split()
    scores = {}

    for i in range(1, 5):
        weights = tuple([1 / i] * i + [0] * (4 - i))
        scores[f"BLEU-{i}"] = sentence_bleu(ref, hyp, weights=weights, smoothing_function=smoothie)

    rouge_scores = rouge.score(reference, hypothesis)
    for key in rouge_scores:
        scores[f"{key.upper()}"] = rouge_scores[key].fmeasure

    return scores

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}, GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
model.to(device)

# Optimizer & Scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=1, threshold=0.03,
    threshold_mode='rel', min_lr=1e-6, verbose=True,
)

# Mixed precision
scaler = GradScaler()
num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    epoch_start_time = time.time()
    print(f"\n🧪 Epoch {epoch + 1}/{num_epochs}")
    progress_bar = tqdm(train_loader, desc="Training", leave=False)

    for step, batch in enumerate(progress_bar):
        pixel_values = batch["pixel_values"].to(device)
        labels = batch["labels"].to(device)
        

        optimizer.zero_grad()
        with autocast():
            outputs = model(pixel_values=pixel_values, labels=labels)
            loss = outputs.loss

        scaler.scale(loss).backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        avg_loss = running_loss / (step + 1)
        progress_bar.set_postfix(loss=f"{avg_loss:.4f}")

    train_loss = running_loss / len(train_loader)
    print(f"✅ Train Loss: {train_loss:.4f} | Time: {(time.time() - epoch_start_time) / 60:.2f} min")

    # Validation
    model.eval()
    val_loss = 0.0
    all_metrics = []
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validation", leave=False):
            pixel_values = batch["pixel_values"].to(device)
            labels = batch["labels"].to(device)
            

            with autocast():
                outputs = model(pixel_values=pixel_values, labels=labels)
                val_loss += outputs.loss.item()

                # Generate captions
                gen_ids = model.generate(pixel_values=pixel_values, max_new_tokens=50)
                preds = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)

                # Clean label sequences by removing -100s
                cleaned_labels = []
                for label_seq in labels:
                    cleaned_seq = [token_id.item() for token_id in label_seq if token_id.item() != -100]
                    cleaned_labels.append(cleaned_seq)

                refs = tokenizer.batch_decode(cleaned_labels, skip_special_tokens=True)

                # Compute metrics
                for ref, pred in zip(refs, preds):
                    metrics = compute_metrics(ref.strip(), pred.strip())
                    all_metrics.append(metrics)

    avg_val_loss = val_loss / len(val_loader)
    print(f"🔍 Validation Loss: {avg_val_loss:.4f}")

    # Aggregate metrics
    avg_metrics = {k: sum(m[k] for m in all_metrics) / len(all_metrics) for k in all_metrics[0]}
    for name, value in avg_metrics.items():
        print(f"{name}: {value:.4f}")

    # LR Scheduler
    prev_lr = optimizer.param_groups[0]['lr']
    scheduler.step(avg_val_loss)
    new_lr = optimizer.param_groups[0]['lr']

    if new_lr < prev_lr:
        print(f"📉 LR reduced from {prev_lr:.6f} to {new_lr:.6f}")
    else:
        print(f"⚠️ LR not reduced. Still at {new_lr:.6f}")
        print(f"🧠 Best val loss so far (tracked by scheduler): {scheduler.best:.6f}")

1 Like

Hmm… At first glance, there doesn’t seem to be any problem.

Perhaps the library version has changed so much that the behavior of the functions or parameters described in the article is different from the current behavior. Discontinuations and specification changes sometimes occur over a period of six months or more.

so if I try out more advanced models then the performance may improve because as far as my knowledge goes there is not any fundamental problem in the training process nor with the model itself …

1 Like