Adapting BLIP2 for zero-shot classification

Hi there

I am attempting to adapt the Blip2Model for a zero-shot classification task as follows:

  • N text sentences/classes → x = N text embeddings
  • 1 test image → y = 1 image embedding
  • soft-max(dot-product(x, y)) to get the probabilities over classes
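
The scoring step in the list above can be sketched in plain Python, without loading any model. This is only an illustration under stated assumptions: `zero_shot_probs` and its `scale` argument are hypothetical names, and the toy vectors stand in for real, unit-normalised embeddings. Because dot products of unit vectors lie in [-1, 1], an unscaled softmax over them is nearly flat, so some temperature-like multiplier is probably needed in practice.

```python
import math

def zero_shot_probs(text_embs, image_emb, scale=1.0):
    """Softmax over dot products between one image embedding and N text embeddings.

    text_embs: list of N vectors (lists of floats), one per class sentence
    image_emb: a single vector of the same length
    scale: hypothetical temperature-like multiplier applied to the raw scores
    """
    # Dot product of the image embedding with each class embedding.
    scores = [scale * sum(t * i for t, i in zip(text, image_emb)) for text in text_embs]
    # Subtract the largest score before exponentiating, for numerical stability.
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Toy unit-norm vectors standing in for real embeddings (assumption).
text_embs = [[1.0, 0.0], [0.0, 1.0], [0.6, 0.8]]
image_emb = [0.6, 0.8]

probs = zero_shot_probs(text_embs, image_emb, scale=10.0)
predicted_class = max(range(len(probs)), key=probs.__getitem__)
```

With tensors, the same computation would be a matrix product between the image and text feature matrices followed by a softmax over the class dimension.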

This is my solution so far:

    import torch
    from PIL import Image

    # `blip2model` and `device` are assumed to be defined elsewhere; the original
    # snippet was lifted from a class and referenced `self.device`.

    def get_img_embedding(images):
        """
        Turn a list of image inputs into a tensor of embedding vectors.
        images is a list of objects with a .path attribute; after preprocessing,
        pixel_values has shape (batch_size, channels, height, width).
        """
        image_tensors = blip2model.preproc([
            Image.open(i.path) # type: ignore
        for i in images], return_tensors='pt') # Dict with 'pixel_values' entry of size batch_size, C, H, W

        image_tensors = image_tensors.to(device, torch.float16) # type: ignore

        # pass images through the vision model and then the qformer to get query-conditional image features
        query_outputs = blip2model.get_qformer_features(**image_tensors) # output with last_hidden_state and pooler_output
        query_output = query_outputs['pooler_output'] # (batch_size, qformer_hidden_size)
        # project query-conditional image features into language space
        image_features = blip2model.language_projection(query_output) # shape (batch_size, language_model_hidden_size)
        # L2-normalise so that the dot product is a cosine similarity
        image_features /= image_features.norm(dim=-1, keepdim=True)

        return image_features

    def get_text_embedding(texts):
        """
        Turn a list of text inputs into a tensor of embedding vectors.
        texts is a list of strings to embed.
        """
        text_tokens = blip2model.text_tokenizer(texts, padding=True, return_tensors='pt')
        text_tokens = text_tokens.to(device)

        text_outputs = blip2model.get_text_features(**text_tokens, output_hidden_states=True) # type: ignore
        # take the first token of the last hidden state as a [CLS]-style embedding, shape (batch_size, hidden_size)
        text_features = text_outputs['hidden_states'][-1][:, 0, :]
        # L2-normalise so that the dot product is a cosine similarity
        text_features /= text_features.norm(dim=-1, keepdim=True)
        return text_features

Then I would take the dot product between the two. Am I on the right track?

Thanks

Have you had any luck with this? I’m also interested in the best way to get BLIP2 images and texts into the same embedding space for retrieval-type purposes (not caption generation per se, but comparing a set of caption embeddings against image embeddings to see which is the best match).

By the way, I don’t know if this is the right place to ask, but the original code (the LAVIS package) has an ITM (image-text matching) head; see here:

ITM_head implementation

Is there a similar or parallel class in the Hugging Face (🤗) implementation that can help with this task, especially one that comes with that head pretrained?