Hi there
I am attempting to adapt the Blip2Model for a zero-shot classification task as follows:
- N text sentences/classes → x = N text embeddings
- 1 test image → y = 1 image embedding
- softmax(dot-product(x, y)) to get the probabilities over the N classes
This is my solution so far:
import torch
from PIL import Image

def get_img_embedding(images):
    """
    Turn a list of image inputs into a tensor of embedding vectors.
    Each item in images has a .path attribute pointing to an image file;
    the preprocessed pixel tensor has shape (batch_size, channels, height, width).
    """
    # dict with a 'pixel_values' entry of shape (batch_size, C, H, W)
    image_tensors = blip2model.preproc(
        [Image.open(i.path) for i in images],  # type: ignore
        return_tensors='pt')
    image_tensors = image_tensors.to(blip2model.device, torch.float16)  # type: ignore
    # pass images through the vision model and then the Q-Former to get query-conditional image features
    query_outputs = blip2model.get_qformer_features(**image_tensors)  # ModelOutput with last_hidden_state and pooler_output
    query_output = query_outputs['pooler_output']  # (batch_size, qformer_hidden_size)
    # project query-conditional image features into the language model's embedding space
    image_features = blip2model.language_projection(query_output)  # (batch_size, lm_hidden_size)
    # L2-normalise so a dot product gives cosine similarity
    image_features /= image_features.norm(dim=-1, keepdim=True)
    return image_features
def get_text_embedding(texts):
    """
    Turn a list of text inputs into a tensor of embedding vectors.
    texts is a list of strings to embed.
    """
    text_tokens = blip2model.text_tokenizer(texts, padding=True, return_tensors='pt')
    text_tokens = text_tokens.to(blip2model.device)
    text_outputs = blip2model.get_text_features(**text_tokens, output_hidden_states=True)  # type: ignore
    # take the first token's embedding from the last hidden state, shape (batch_size, lm_hidden_size)
    text_features = text_outputs['hidden_states'][-1][:, 0, :]
    # L2-normalise so a dot product gives cosine similarity
    text_features /= text_features.norm(dim=-1, keepdim=True)
    return text_features
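(For context, `blip2model` above is not the raw Hugging Face `Blip2Model` but a thin convenience wrapper; `preproc` and `text_tokenizer` are just my names for the processor's image processor and tokenizer. A simplified sketch of it:)

import torch
from transformers import Blip2Model, Blip2Processor

class Blip2Wrapper:
    """Bundle the BLIP-2 model with the processor pieces used above."""
    def __init__(self, name='Salesforce/blip2-opt-2.7b', device='cuda'):
        processor = Blip2Processor.from_pretrained(name)
        self.preproc = processor.image_processor    # PIL images -> pixel_values
        self.text_tokenizer = processor.tokenizer   # strings -> input_ids / attention_mask
        model = Blip2Model.from_pretrained(name, torch_dtype=torch.float16).to(device)
        self.device = model.device
        self.get_qformer_features = model.get_qformer_features
        self.get_text_features = model.get_text_features
        self.language_projection = model.language_projection

blip2model = Blip2Wrapper()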
Then I would take the dot product between the two and apply a softmax over the classes, as sketched below. Am I on the right track?
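A rough sketch of that final step (placeholder names: `class_texts` is the list of N class prompts, `test_images` the images to classify):

text_embs = get_text_embedding(class_texts)    # (N, lm_hidden_size), L2-normalised
img_embs = get_img_embedding(test_images)      # (batch_size, lm_hidden_size), L2-normalised

logits = img_embs @ text_embs.T                # cosine similarities, shape (batch_size, N)
probs = torch.softmax(logits.float(), dim=-1)  # probabilities over the N classes
preds = probs.argmax(dim=-1)                   # predicted class index per image

(I'm not applying any temperature scaling to the similarities before the softmax.)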
Thanks