I am attempting to adapt the Blip2Model for a zero-shot classification task as follows:
- N text sentences/classes → x = N text embeddings
- 1 test image → y = 1 image embedding
- softmax(dot-product(x, y)) to get the probabilities over the N classes
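The scoring step above can be sketched with stand-in tensors (random embeddings here, purely to show shapes and the softmax over classes; `N` and `hidden_size` are assumed values):

```python
import torch

# Hypothetical stand-ins for the real embeddings: N unit-normalized text
# embeddings of shape (N, hidden_size) and 1 unit-normalized image embedding.
N, hidden_size = 5, 768
text_features = torch.randn(N, hidden_size)
text_features /= text_features.norm(dim=-1, keepdim=True)
image_features = torch.randn(1, hidden_size)
image_features /= image_features.norm(dim=-1, keepdim=True)

# Dot products of unit vectors are cosine similarities; softmax over the
# class dimension turns them into a probability distribution.
logits = image_features @ text_features.T  # shape (1, N)
probs = logits.softmax(dim=-1)             # shape (1, N), rows sum to 1
```

Note that CLIP-style models usually multiply these logits by a learned temperature (logit scale) before the softmax; raw cosine similarities lie in [-1, 1], so without that scaling the resulting distribution is quite flat.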
This is my solution so far:
```python
def get_img_embedding(images):
    """
    Turn a list of image inputs into a tensor of embedding vectors.
    images should be of shape (batch_size, channels, height, width)
    """
    # Dict with a 'pixel_values' entry of shape (batch_size, C, H, W)
    image_tensors = blip2model.preproc(
        [Image.open(i.path) for i in images],  # type: ignore
        return_tensors='pt')
    image_tensors = image_tensors.to(blip2model.device, torch.float16)  # type: ignore
    # Pass images through the vision model, then the Q-Former,
    # to get query-conditional image features
    query_outputs = blip2model.get_qformer_features(**image_tensors)  # (last_hidden_state, pooler_output)
    query_output = query_outputs['pooler_output']  # (batch_size, hidden_size)
    # Project query-conditional image features into the language space
    image_features = blip2model.language_projection(query_output)  # (batch_size, hidden_size)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    return image_features


def get_text_embedding(texts):
    """
    Turn a list of text inputs into a tensor of embedding vectors.
    texts is a list of strings to embed.
    """
    text_tokens = blip2model.text_tokenizer(texts, padding=True, return_tensors='pt')
    text_tokens = text_tokens.to(blip2model.device)
    text_outputs = blip2model.get_text_features(**text_tokens, output_hidden_states=True)  # type: ignore
    # Extract the [CLS] embedding from the last hidden state, shape (batch_size, hidden_size)
    text_features = text_outputs['hidden_states'][-1][:, 0, :]
    text_features /= text_features.norm(dim=-1, keepdim=True)
    return text_features
```
Then I would take the dot product between the two sets of embeddings and apply a softmax over the classes. Am I on the right track?