Solution for Fine Tuning the Blip Model

Hello Everyone,

I am working on Imnage Captioning Project using Blip model. Flicker30 Dataset i am using and trying BLIP-large to fine tune on it, however, some error is coming. I am attaching the code. Kindly help me to resolve it.

import os
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AdamW
from PIL import Image
from model import model, processor # Importing the model and processor from model.py
from dataset import load_captions_from_txt
from preprocess import preprocess_image, preprocess_caption, preprocess_dataset
from torch.nn.utils.rnn import pad_sequence

class CaptioningDataset(Dataset):
def init(self, image_dir, captions_dict, processor):
self.image_dir = image_dir
self.captions_dict = captions_dict
self.processor = processor
self.image_filenames = list(captions_dict.keys())

def __len__(self):
    return len(self.image_filenames)

def __getitem__(self, idx):
    # Get the image filename and image path
    image_filename = self.image_filenames[idx]
    image_path = os.path.join(self.image_dir, image_filename)

    # Open the image using PIL (make sure it's in RGB format)
    image = Image.open(image_path).convert("RGB")

    # Use the processor to handle image preprocessing
    image_input = self.processor(images=image, return_tensors="pt")

    # Get captions from the dictionary
    captions = self.captions_dict[image_filename]

    # Process captions using the tokenizer (assuming it's a separate tokenizer for text)
    caption_inputs = [self.processor.tokenizer(caption, return_tensors="pt", padding=True, truncation=True) for caption in captions]

    return image_input, caption_inputs

def collate_fn(batch):
# batch is a list of tuples (image_inputs, caption_inputs)

# Separate images and captions
image_inputs = [item[0] for item in batch]
caption_inputs = [item[1] for item in batch]

# Use the tokenizer's padding and truncation feature to ensure all captions have the same size
# processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

# Extract input_ids for each caption and apply padding and truncation
caption_input_ids = [processor(caption['caption'], padding="max_length", truncation=True, return_tensors="pt")['input_ids'] for caption in caption_inputs]

# Stack all caption input_ids into a single tensor (ensure they're padded to the same length)
caption_input_ids = torch.cat(caption_input_ids, dim=0)

# Return image inputs and the padded caption input_ids
return image_inputs, caption_input_ids

Hyperparameters

batch_size = 8
learning_rate = 5e-5
num_epochs = 3

Load captions

captions_file = ‘/Users/arpitsharma/cvision/captions.txt’ # Path to your captions file
image_dir = ‘/Users/arpitsharma/cvision/Images/flickr30k_images’ # Path to your images

captions_dict = load_captions_from_txt(captions_file)

Preprocess the dataset

preprocessed_data = preprocess_dataset(image_dir, captions_dict, processor)

Create DataLoader

dataset = CaptioningDataset(image_dir, captions_dict, processor)
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

Initialize the optimizer

optimizer = AdamW(model.parameters(), lr=learning_rate)

Training loop

device = torch.device(“mps” if torch.backends.mps.is_available() else “cpu”) # Use MPS for Apple Silicon

model.to(device)

for epoch in range(num_epochs):
model.train() # Set model to training mode
total_loss = 0

for batch_idx, (image_inputs, caption_inputs) in enumerate(train_dataloader):
    image_inputs = {key: value.to(device) for key, value in image_inputs.items()}  # Move inputs to device
    input_ids = caption_inputs['input_ids'].to(device)
    attention_mask = caption_inputs['attention_mask'].to(device)

    # Forward pass
    outputs = model(input_ids=input_ids, attention_mask=attention_mask, images=image_inputs['pixel_values'], labels=input_ids)
    loss = outputs.loss

    # Backpropagation
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    total_loss += loss.item()

    if batch_idx % 10 == 0:  # Print loss every 10 batches
        print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(train_dataloader)}], Loss: {loss.item()}")

print(f"Epoch [{epoch+1}/{num_epochs}] completed. Total loss: {total_loss}")

Save the fine-tuned model

output_dir = ‘/Users/arpitsharma/cvision’
if not os.path.exists(output_dir):
os.makedirs(output_dir)

model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)

print(f"Fine-tuned model saved to {output_dir}")

1 Like