Finetune BLIP on custom dataset #20893

Dear team,

I was trying to finetune BLIP and ran into an error that I'm not sure how to solve. Could you give me some advice? Thanks!

import torch
import requests
from PIL import Image
from transformers import BlipProcessor, BlipForQuestionAnswering

model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

class VQADataset(torch.utils.data.Dataset):
    """VQA (v2) dataset."""

    def __init__(self, questions, answers, image_paths, processor):
        self.questions = questions
        self.answers = answers
        self.image_paths = image_paths
        self.processor = processor

    def __len__(self):
        return len(self.questions)

    def __getitem__(self, idx):
        # get image + text
        question = self.questions[idx]
        answer = self.answers[idx]
        image = Image.open(self.image_paths[idx]).convert("RGB")
        text = question
        
        # encode the image + question, and the answer as labels
        encoding = self.processor(image, text, padding="max_length", truncation=True, return_tensors="pt")
        labels = self.processor.tokenizer.encode(
            answer, max_length=512, padding="max_length", return_tensors="pt"
        )
        encoding["labels"] = labels

        # remove batch dimension
        # for k,v in encoding.items():  encoding[k] = v.squeeze()
        return encoding

from torch.utils.data import DataLoader
from tqdm import tqdm

def collate_fn(batch):
    # NOTE: the lists below are built but never used; the raw list of
    # per-example encodings is returned unchanged, which is why `batch[0]`
    # is indexed further down.
    input_ids = [item['input_ids'] for item in batch]
    pixel_values = [item['pixel_values'] for item in batch]
    attention_mask = [item['attention_mask'] for item in batch]
    labels = [item['labels'] for item in batch]

    return batch


questions = [...]    # list of questions
answers = [...]      # list of corresponding answers
image_paths = [...]  # list of paths to the corresponding images

train_dataset = VQADataset(questions=questions,
                           answers=answers,
                           image_paths=image_paths,
                           processor=processor)

test_dataset = VQADataset(questions=questions,
                          answers=answers,
                          image_paths=image_paths,
                          processor=processor)

batch_size = 1
train_dataloader = DataLoader(train_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, collate_fn=collate_fn, batch_size=batch_size, shuffle=False)

batch = next(iter(train_dataloader))

print(batch[0].keys()) # dict_keys(['pixel_values', 'input_ids', 'attention_mask', 'labels'])

import copy
test_input = copy.copy(batch[0]).to(device)  # batch[0] is a single BatchEncoding; .to() moves its tensors to the device
outputs = model(**test_input)

Example of the input:

questions = ["How many cats are there?"]
answers = ["two"]
image_paths = ["./img_125.png"]

Output:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-27-f4758beea430> in <module>
      2 
      3 test_input = copy.copy(batch[0]).to(device)
----> 4 outputs = model(**test_input)

6 frames
/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   3024     if size_average is not None or reduce is not None:
   3025         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3026     return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
   3027 
   3028 

ValueError: Expected input batch_size (0) to match target batch_size (511).

Tagging you as you suggested, @younesbelkada.

Hi @dxlong2000
Thanks for the issue!
Could you share the full traceback of the error?

Hi @ybelkada, below is the full Traceback:

Traceback (most recent call last):
  File "<ipython-input-20-ab236b1c9c06>", line 7, in <module>
    outputs = model(**test_input)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/blip/modeling_blip.py", line 1200, in forward
    answer_output = self.text_decoder(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/blip/modeling_blip_text.py", line 903, in forward
    lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/loss.py", line 1174, in forward
    return F.cross_entropy(input, target, weight=self.weight,
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/functional.py", line 3026, in cross_entropy
    return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
ValueError: Expected input batch_size (0) to match target batch_size (511).

Hi @ybelkada, any idea?

Hi @dxlong2000

This should be supported by `blip` support for training by younesbelkada · Pull Request #21021 · huggingface/transformers · GitHub.
If you want to try it out now:

pip install git+https://github.com/younesbelkada/transformers@blip-train-support

Thanks!

@ybelkada: I am trying to use the BLIP model from Hugging Face, but it seems it is not yet part of transformers, as I am getting this error:
"cannot import name 'BlipProcessor' from 'transformers'"

I installed transformers and huggingface via pip.

Do you know by chance what the problem is?

Hi @Shahabhm

You should install transformers from the main branch to benefit from it:

pip install git+https://github.com/huggingface/transformers@main

@dxlong2000 now that training support has been merged, you can install the library from source as above and try running your script.
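
For reference, here is a minimal sketch of what a single training loop could look like once training support is in. It assumes a dataloader that yields dicts of stacked `input_ids`, `pixel_values` and `labels` tensors (the collate_fn above returns a raw list instead), so names and shapes here are illustrative rather than the exact script from this thread:

import torch
from transformers import BlipProcessor, BlipForQuestionAnswering

# Sketch only: assumes `train_dataloader` yields dicts of stacked tensors
# ("input_ids", "pixel_values", "labels").
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-base")
model = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-base")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.train()

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

for batch in train_dataloader:
    input_ids = batch["input_ids"].to(device)        # tokenized question
    pixel_values = batch["pixel_values"].to(device)  # preprocessed image
    labels = batch["labels"].to(device)              # tokenized answer

    outputs = model(input_ids=input_ids,
                    pixel_values=pixel_values,
                    labels=labels)
    loss = outputs.loss

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()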

Hi @ybelkada, ohh very nice! Thank you very much!

Thank you @ybelkada
I was trying to fine-tune BLIP image captioning on a custom dataset, based on the following example: Google Colab

However, I am getting Out of Memory errors (running on 1 GPU), even with batch size = 8 and a half-precision model (float16).

Any idea how to solve this?

best
Shahab

Hi @Shahabhm

Could you share a script or a Colab where you observe your issue? Thanks!

One thing I wanted to explore is using gradient accumulation with micro batch sizes.

@ybelkada Sure, but the image/text dataset cannot be shared, unfortunately, and code-wise I have not changed a lot;
I have mostly changed the data loading part.
One more important thing:
the loss becomes "inf" in the first batch and then NaN in subsequent batches. Does that tell you anything?


Here is the code:

"""
!pip install git+https://github.com/huggingface/transformers.git@main


!pip install -U albumentations


import os

import gc
import numpy as np
import pandas as pd
import itertools
from tqdm import tqdm
import albumentations as A

import torch
from torch import nn
import torch.nn.functional as F

from transformers import DistilBertModel, DistilBertConfig, DistilBertTokenizer
pd_padma_existing = pd.read_csv("loading data")[['custom_desc', 'castor', 'image']]


class CFG:
    # text length
    max_length = 500 
    # image size
    size = 224




"""## Create PyTorch Dataset

The lines below are entirely copied from the original notebook!
"""

from torch.utils.data import Dataset, DataLoader

class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        encoding = self.processor(images=item["image"], text=item["text"], padding="max_length", return_tensors="pt")
        # remove batch dimension
        encoding = {k:v.squeeze() for k,v in encoding.items()}
        return encoding

# COMMAND ----------

import cv2

class BLIPDataset(torch.utils.data.Dataset):
    def __init__(self, image_filenames, captions, castors, processor):
        """
        image_filenames and cpations must have the same length; so, if there are
        multiple captions for each image, the image_filenames must have repetitive
        file names 
        """
 
        self.image_filenames = image_filenames
        self.captions = captions
        self.castors = castors
        self.processor = processor      
        self.transforms= A.Compose(
            [
                A.Resize(CFG.size, CFG.size, always_apply=True),
                A.Normalize(max_pixel_value=255.0, always_apply=True),
            ])
    
 
    def __getitem__(self, idx):
       
        image = cv2.imread(self.image_filenames[idx])
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = self.transforms(image=image)['image']
        item_image = torch.tensor(image).permute(2, 0, 1).float()
        item_text = self.captions[idx][:300]       
        encoding = self.processor(images=item_image, text=item_text, padding="max_length", return_tensors="pt").to("cuda", torch.float16)
        # remove batch dimension
        encoding = {k:v.squeeze() for k,v in encoding.items()}
        b = {}
        b['encoding'] = encoding
        b['item_image'] = item_image
        b['item_text'] = item_text
        b['castor'] = self.castors[idx]
        #return encoding, item_image, item_text
        return b
 
 
    def __len__(self):
        return len(self.captions)
 
 

# COMMAND ----------

"""## Load model and processor"""

from transformers import BlipProcessor, AutoProcessor


processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

blip_dataset = BLIPDataset(
      pd_padma_existing["image"].values,
      pd_padma_existing["caption"].values,
      pd_padma_existing["castor"].values,
      processor
  )



# COMMAND ----------

from transformers import BlipProcessor, BlipForImageTextRetrieval,BlipForConditionalGeneration
import torch



"""Now that we have loaded the processor, let's load the dataset and the dataloader:"""

train_dataloader = DataLoader(blip_dataset, shuffle=False, batch_size=32)

"""## Train the model

Let's train the model! Run the simply the cell below for training the model
"""

model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float16)
#model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-large-flickr", torch_dtype=torch.float16)

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

model.train()

for epoch in range(1):
  print("Epoch:", epoch)
  for idx, b in enumerate(train_dataloader):

    batch, item_image, item_text, castor = b['encoding'], b['item_image'], b['item_text'], b['castor']
    input_ids = batch.pop("input_ids").to(device)
    pixel_values = batch.pop("pixel_values").to(device)

    outputs = model(input_ids=input_ids,
                    pixel_values=pixel_values,
                    labels=input_ids)
    #outputs = model(input_ids=input_ids, pixel_values=pixel_values)

    loss = outputs.loss

    print("idx=", idx, "Loss:", loss.item())

    loss.backward()

    optimizer.step()
    optimizer.zero_grad()
    input_ids = input_ids.to("cpu")
    pixel_values = pixel_values.to("cpu")

@ybelkada: one more question, I am interested in using embeddings from BLIP (after fine-tuning) for downstream tasks like classification. From which layer or output should I extract these embeddings?

Hi,
Thanks for the message.
Training in pure fp16 does indeed seem to be unstable. Hence, I would advise you to use torch.amp.autocast instead; check this nice recent thread from PyTorch on why this is unstable: Incorrect MSE loss for float16 - #2 by ptrblck - PyTorch Forums
Replacing the training loop with the one below worked for me with batch_size=8:

import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

model.train()

for epoch in range(50):
  print("Epoch:", epoch)
  for idx, batch in enumerate(train_dataloader):
    input_ids = batch.pop("input_ids").to(device)
    pixel_values = batch.pop("pixel_values").to(device)

    with torch.autocast(device_type='cuda', dtype=torch.float16):
      outputs = model(input_ids=input_ids,
                      pixel_values=pixel_values,
                      labels=input_ids)
    
    loss = outputs.loss

    print("Loss:", loss.item())

    loss.backward()

    optimizer.step()
    optimizer.zero_grad()

Note that you need to keep your model in fp32, i.e. there is no need to load it with torch_dtype=torch.float16.
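
As a side note (this is a sketch of mine, not part of the original reply): fp16 autocast training is usually paired with torch.cuda.amp.GradScaler, which can also help against inf/NaN losses. Assuming the same model, optimizer, device and train_dataloader as above:

import torch

scaler = torch.cuda.amp.GradScaler()  # scales the loss to avoid fp16 under/overflow

model.train()
for idx, batch in enumerate(train_dataloader):
    input_ids = batch.pop("input_ids").to(device)
    pixel_values = batch.pop("pixel_values").to(device)

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        outputs = model(input_ids=input_ids,
                        pixel_values=pixel_values,
                        labels=input_ids)
        loss = outputs.loss

    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales gradients, then calls optimizer.step()
    scaler.update()
    optimizer.zero_grad()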

Regarding your second question, this depends on the exact tasks: text classification or image classification? Can you give more details on the downstream tasks you want to perform?

Thanks for your hints, I will try it shortly. But two important things:
1- The loss value becomes "inf" in the first batch and then "NaN" in later batches. What went wrong?
2- My goal is to group items based on their embedding similarity and also classify them into concepts like "table", "chair", …

@ybelkada: it only works for me with batch_size=4!
But is this small batch size reliable?

Thanks for your reply
In this case you can combine the script I just sent with gradient accumulation, using accumulation_step=2 to be equivalent to batch_size=8 (see the sketch below). Check this nice blog post: Gradient Accumulation in PyTorch | Nikita Kozodoi!
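
A minimal sketch of how that could look, reusing the autocast loop above (names are illustrative; accumulation_steps=2 is just the value suggested here):

accumulation_steps = 2  # effective batch size = batch_size * accumulation_steps

model.train()
optimizer.zero_grad()
for idx, batch in enumerate(train_dataloader):
    input_ids = batch.pop("input_ids").to(device)
    pixel_values = batch.pop("pixel_values").to(device)

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        outputs = model(input_ids=input_ids,
                        pixel_values=pixel_values,
                        labels=input_ids)

    # divide so the accumulated gradient matches a larger-batch gradient
    loss = outputs.loss / accumulation_steps
    loss.backward()

    if (idx + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()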

Regarding 2- I think you might be interested in retrieving the image embeddings. You can retrieve them with:

model.vision_model(**inputs)[0]
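
For example, a sketch of extracting a single image embedding this way (taking the CLS token of the last hidden state is my choice here, not something prescribed in this thread):

import torch
from PIL import Image
from transformers import AutoProcessor, BlipForConditionalGeneration

processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

image = Image.open("./img_125.png").convert("RGB")     # any image path
inputs = processor(images=image, return_tensors="pt")

with torch.no_grad():
    last_hidden_state = model.vision_model(**inputs)[0]  # (batch, num_patches + 1, hidden)
    image_embedding = last_hidden_state[:, 0, :]          # CLS token as a single vector

# `image_embedding` can then be used for cosine-similarity grouping or as
# input features to a small classifier.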

Wow, thank you! Very useful tips.
Without gradient accumulation, the loss bounces around 2 and does not improve.
Is that an expected loss?

Hi @Shahabhm
Thanks for the message
Here I don't really know; my gut feeling is that you need to apply some data augmentation to your dataset (e.g. random flipping, etc.), as in the sketch below.
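
For example, a sketch extending the albumentations pipeline used in BLIPDataset above (the extra transforms are only illustrations, not a prescription):

import albumentations as A

train_transforms = A.Compose(
    [
        A.Resize(CFG.size, CFG.size, always_apply=True),
        A.HorizontalFlip(p=0.5),                # random horizontal flip
        A.RandomBrightnessContrast(p=0.2),      # mild brightness/contrast jitter
        A.Normalize(max_pixel_value=255.0, always_apply=True),
    ])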