Adapting BLIP2 for zero-shot classification

Hi there

I am attempting to adapt the Blip2Model for a zero-shot classification task as follows:

  • N text sentences/classes → x = N text embeddings
  • 1 test image → y = 1 image embedding
  • softmax(dot-product(x, y)) to get the probabilities over classes

This is my solution so far:

    def get_img_embedding(self, images):
        """
        Turn a list of image inputs into a tensor of embedding vectors.
        images is a list of objects whose .path attribute points to an image file;
        the processor returns pixel values of shape (batch_size, channels, height, width).
        """
        image_tensors = blip2model.preproc([
            Image.open(i.path) # type: ignore
        for i in images], return_tensors='pt') # dict with a 'pixel_values' entry of shape (batch_size, C, H, W)

        image_tensors = image_tensors.to(self.device, torch.float16) # type: ignore

        # pass images through the vision model and then the qformer to get query-conditional image features
        query_outputs = blip2model.get_qformer_features(**image_tensors) # model output with 'last_hidden_state' and 'pooler_output'
        query_output = query_outputs['pooler_output'] # (batch_size, hidden_size)
        # project query-conditional image features into language space
        image_features = blip2model.language_projection(query_output) # shape (batch_size, hidden_size)
        image_features /= image_features.norm(dim=-1, keepdim=True)

        return image_features

    def get_text_embedding(self, texts):
        """
        Turn a list of text inputs into tensor of embedding vectors.
        texts is a list of strings to embed.
        """

        text_tokens = blip2model.text_tokenizer(texts, padding=True, return_tensors='pt')
        text_tokens = text_tokens.to(self.device) 

        text_outputs = blip2model.get_text_features(**text_tokens, output_hidden_states=True) # type: ignore
        text_features = text_outputs['hidden_states'][-1][:, 0, :] # extract [CLS] embedding from last hidden state, shape (batch_size, hidden_size)
        text_features /= text_features.norm(dim=-1, keepdim=True)
        return text_features

Then I would take the dot product between the two. Am I on the right track?
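
Concretely, I was thinking of something like this for the final step (just a sketch; clf stands for whatever object ends up holding the two methods above, and test_image is any object with a .path attribute):

    classes = ["a photo of a cat", "a photo of a dog", "a photo of a bird"] # example class prompts
    text_features = clf.get_text_embedding(classes)      # (num_classes, hidden_size)
    image_features = clf.get_img_embedding([test_image]) # (1, hidden_size)

    # both sets of embeddings are already L2-normalised, so the dot products are cosine similarities
    probs = (image_features @ text_features.T).softmax(dim=-1) # (1, num_classes)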

Thanks

Have you had any luck with this? I’m also interested in knowing the best way to get BLIP2 images and texts into the same embedding space for retrieval-type purposes. (Not caption generation per se, but rather comparing a set of caption embeddings and image embeddings to see which has the best match.)
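
Concretely, something like this is what I mean by retrieval (just a sketch, assuming embedding functions like the ones above that return L2-normalised vectors):

    # caption_features: (num_captions, hidden_size), image_features: (num_images, hidden_size)
    sim = image_features @ caption_features.T   # (num_images, num_captions) cosine similarities
    best_caption_per_image = sim.argmax(dim=-1) # index of the best-matching caption for each image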

BTW, I don’t know if this is the correct place to ask, but in the original code (the lavis package) there is an ITM (image-text matching) head, see here:

ITM_head implementation

Is there a similar / parallel class in the Hugging Face implementation that can help with this task, especially with that head pretrained?
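
For reference, this is roughly how that head is used on the LAVIS side (a sketch based on the LAVIS examples; the model name and type strings may differ depending on the installed version):

    import torch
    from PIL import Image
    from lavis.models import load_model_and_preprocess

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, vis_processors, text_processors = load_model_and_preprocess(
        "blip2_image_text_matching", "pretrain", device=device, is_eval=True
    )

    raw_image = Image.open("example.jpg").convert("RGB") # placeholder image path
    caption = "a photo of a dog"                         # placeholder caption
    img = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
    txt = text_processors["eval"](caption)

    # 'itm' runs the image-text matching head (2-way classification),
    # 'itc' returns the contrastive image-text similarity
    itm_logits = model({"image": img, "text_input": txt}, match_head="itm")
    itm_prob = torch.softmax(itm_logits, dim=1)[:, 1].item()
    itc_score = model({"image": img, "text_input": txt}, match_head="itc").item()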

Were you able to solve the task? I noticed that you are using a slightly different approach with respect to [1].
In the previous post, the output field qformer_outputs.last_hidden_state is used to synthesize the information from the Q-Former using the Blip2ForConditionalGeneration class, whereas your approach seems to use Blip2Model.

As far as I understand, the Q-Former already makes use of the vision model to generate its output. Could anyone with more experience explain which of these two methods is more effective?
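
For concreteness, this is roughly what the last_hidden_state variant would look like for the classification setup above (just a sketch reusing blip2model, image_tensors and text_features from the first post; I am not sure this is the intended usage):

    # keep all Q-Former query tokens instead of the pooled output
    query_outputs = blip2model.get_qformer_features(**image_tensors)
    query_tokens = query_outputs['last_hidden_state']              # (batch, num_query_tokens, qformer_hidden)
    image_features = blip2model.language_projection(query_tokens)  # (batch, num_query_tokens, hidden_size)
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    # score each class by its best-matching query token, then softmax over classes
    # (this max-over-queries aggregation mirrors how the original BLIP-2 ITC similarity is computed)
    sims = torch.einsum('bqh,ch->bqc', image_features, text_features) # (batch, num_query_tokens, num_classes)
    probs = sims.max(dim=1).values.softmax(dim=-1)                    # (batch, num_classes)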