Text classification using BLIP-2

Hi everyone!

I was wondering whether my approach to the following problem is correct. I am trying to use the BLIP-2 model to perform classification on a small dataset.

My approach is the following:

  1. Run the prompts and images through the model (using Blip2ForConditionalGeneration).
  2. Retrieve the Q-Former's last hidden state.
  3. Pass it through a linear layer that maps the features to an intermediate dimension (say 1024).
  4. Finally, map the intermediate dimension to class logits with a classification layer sized to the number of classes in the problem.

Does this look fine? Below is a simplified version of my approach:

    def forward(self, data, labels):
        out = self.base_model(**data, output_hidden_states=True)

        # flatten the Q-Former features:
        # (batch, num_query_tokens, hidden) -> (batch, num_query_tokens * hidden)
        features = out.qformer_outputs.last_hidden_state
        features = features.view(features.size(0), -1)

        # map the features to an intermediate dimension (e.g. 1024)
        x = self.fusion(features)

        # compute class logits
        logits = self.classifier(x)

        # apply cross-entropy loss and return it alongside the logits
        loss = self.criterion(logits, labels)
        return loss, logits

As a note regarding step 1 above: as far as I understand, the Q-Former is the component that fuses the information from the image encoder with the textual information from the prompt (according to this article). As a result, the output of the vision model can be ignored, since it has already been processed by the Q-Former. I would be grateful for some feedback.