Hello Everyone,
I am working on an image captioning project using the BLIP model. I am using the Flickr30k dataset and trying to fine-tune BLIP-large on it, but I am running into an error. I am attaching the code below. Kindly help me resolve it.
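For reference, model.py does little more than load the BLIP-large captioning checkpoint and its processor from the Hugging Face Hub, roughly like this (sketch; the exact checkpoint string is from memory, so treat it as approximate):

# model.py (rough sketch of what it contains)
from transformers import BlipProcessor, BlipForConditionalGeneration

checkpoint = "Salesforce/blip-image-captioning-large"  # BLIP-large captioning checkpoint
processor = BlipProcessor.from_pretrained(checkpoint)
model = BlipForConditionalGeneration.from_pretrained(checkpoint)

And here is the full training script where the error comes up: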
import os
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import AdamW
from PIL import Image
from model import model, processor # Importing the model and processor from model.py
from dataset import load_captions_from_txt
from preprocess import preprocess_image, preprocess_caption, preprocess_dataset
from torch.nn.utils.rnn import pad_sequence
class CaptioningDataset(Dataset):
    def __init__(self, image_dir, captions_dict, processor):
        self.image_dir = image_dir
        self.captions_dict = captions_dict
        self.processor = processor
        self.image_filenames = list(captions_dict.keys())

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        # Get the image filename and image path
        image_filename = self.image_filenames[idx]
        image_path = os.path.join(self.image_dir, image_filename)
        # Open the image using PIL (make sure it's in RGB format)
        image = Image.open(image_path).convert("RGB")
        # Use the processor to handle image preprocessing
        image_input = self.processor(images=image, return_tensors="pt")
        # Get captions from the dictionary
        captions = self.captions_dict[image_filename]
        # Process captions using the tokenizer (assuming it's a separate tokenizer for text)
        caption_inputs = [self.processor.tokenizer(caption, return_tensors="pt", padding=True, truncation=True) for caption in captions]
        return image_input, caption_inputs
def collate_fn(batch):
    # batch is a list of tuples (image_inputs, caption_inputs)
    # Separate images and captions
    image_inputs = [item[0] for item in batch]
    caption_inputs = [item[1] for item in batch]
    # Use the tokenizer's padding and truncation feature to ensure all captions have the same size
    # processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
    # Extract input_ids for each caption and apply padding and truncation
    caption_input_ids = [processor(caption['caption'], padding="max_length", truncation=True, return_tensors="pt")['input_ids'] for caption in caption_inputs]
    # Stack all caption input_ids into a single tensor (ensure they're padded to the same length)
    caption_input_ids = torch.cat(caption_input_ids, dim=0)
    # Return image inputs and the padded caption input_ids
    return image_inputs, caption_input_ids
# Hyperparameters
batch_size = 8
learning_rate = 5e-5
num_epochs = 3
# Load captions
captions_file = "/Users/arpitsharma/cvision/captions.txt"  # Path to your captions file
image_dir = "/Users/arpitsharma/cvision/Images/flickr30k_images"  # Path to your images
captions_dict = load_captions_from_txt(captions_file)
# Preprocess the dataset
preprocessed_data = preprocess_dataset(image_dir, captions_dict, processor)
# Create DataLoader
dataset = CaptioningDataset(image_dir, captions_dict, processor)
train_dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
# Initialize the optimizer
optimizer = AdamW(model.parameters(), lr=learning_rate)
# Training loop
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")  # Use MPS for Apple Silicon
model.to(device)
for epoch in range(num_epochs):
    model.train()  # Set model to training mode
    total_loss = 0
    for batch_idx, (image_inputs, caption_inputs) in enumerate(train_dataloader):
        image_inputs = {key: value.to(device) for key, value in image_inputs.items()}  # Move inputs to device
        input_ids = caption_inputs['input_ids'].to(device)
        attention_mask = caption_inputs['attention_mask'].to(device)
        # Forward pass
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, images=image_inputs['pixel_values'], labels=input_ids)
        loss = outputs.loss
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        if batch_idx % 10 == 0:  # Print loss every 10 batches
            print(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(train_dataloader)}], Loss: {loss.item()}")
    print(f"Epoch [{epoch+1}/{num_epochs}] completed. Total loss: {total_loss}")
# Save the fine-tuned model
output_dir = "/Users/arpitsharma/cvision"
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)
print(f"Fine-tuned model saved to {output_dir}")