Can BlipForImageTextRetrieval be used to generate captions?

I am looking for a BLIP model that can serve two purposes: predicting the similarity between an input image and a text, and generating a caption for an input image. I am aware that BlipForImageTextRetrieval is suitable for predicting the similarity between an image and a text, while BlipForConditionalGeneration can generate captions for images. However, I was wondering whether either of these models can be used to perform the other task as well.

A bit more context: I have a fine-tuned BlipForImageTextRetrieval model that I would like to use for generating captions.

Any guidance on obtaining a BLIP model that can handle both of the tasks mentioned above would be extremely helpful. Thanks.

Hi,

BlipForImageTextRetrieval consists of a vision encoder and a text encoder (similar to CLIP), and is trained using an image-text contrastive (ITC) loss and an image-text matching (ITM) loss.

BlipForConditionalGeneration consists of a vision encoder and a text decoder, and is trained using a language modeling loss.
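For reference, here is how the two existing classes are typically used for their respective tasks. This is a minimal sketch; the checkpoint names are the standard Salesforce ones on the Hub, so swap in your own fine-tuned models as needed:

import requests
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForImageTextRetrieval, BlipForConditionalGeneration

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Retrieval: the ITM head scores how well the text matches the image
processor = BlipProcessor.from_pretrained("Salesforce/blip-itm-base-coco")
model = BlipForImageTextRetrieval.from_pretrained("Salesforce/blip-itm-base-coco")
inputs = processor(images=image, text="two cats sleeping on a couch", return_tensors="pt")
with torch.no_grad():
    itm_logits = model(**inputs).itm_score  # shape (batch_size, 2)
match_probability = itm_logits.softmax(dim=1)[:, 1]

# Captioning: the text decoder generates a caption for the image
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
inputs = processor(images=image, return_tensors="pt")
generated_ids = model.generate(**inputs, max_new_tokens=20)
print(processor.decode(generated_ids[0], skip_special_tokens=True))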

Hence, you could create a new model (let’s call it BlipForRetrievalAndCaptioning) that includes the vision encoder, the text encoder, and the text decoder. This model can then handle retrieval with its vision and text encoders, and captioning with its vision encoder and text decoder. A sketch of the implementation looks like this:

from transformers import BlipPreTrainedModel, BlipConfig
from transformers.models.blip.modeling_blip import BlipVisionModel
from transformers.models.blip.modeling_blip_text import BlipTextModel, BlipTextLMHeadModel

class BlipForRetrievalAndCaptioning(BlipPreTrainedModel):
    config_class = BlipConfig
    _tied_weights_keys = ["text_decoder.cls.predictions.decoder.bias"]
    main_input_name = "pixel_values"

    def __init__(self, config: BlipConfig):
        super().__init__(config)

        self.vision_model = BlipVisionModel(config.vision_config)
        self.text_encoder = BlipTextModel(config.text_config, add_pooling_layer=False)
        self.text_decoder = BlipTextLMHeadModel(config.text_config)

        self.decoder_input_ids = config.text_config.bos_token_id
        self.decoder_pad_token_id = config.text_config.pad_token_id

        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, pixel_values, input_ids, attention_mask=None, do_retrieval=True):
        if do_retrieval:
            # see the forward of BlipForImageTextRetrieval
            ...
        else:
            # see the forward of BlipForConditionalGeneration
            ...

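To wire this up with your fine-tuned checkpoint, you could then instantiate the combined model and copy the weights over: the vision and text encoders from your fine-tuned BlipForImageTextRetrieval, and the text decoder from a captioning checkpoint. A rough sketch, assuming both checkpoints share the same base architecture ("your-username/blip-itm-finetuned" is a placeholder for your own model):

from transformers import BlipConfig, BlipForImageTextRetrieval, BlipForConditionalGeneration

retrieval = BlipForImageTextRetrieval.from_pretrained("your-username/blip-itm-finetuned")
captioning = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

config = BlipConfig.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForRetrievalAndCaptioning(config)

# reuse the fine-tuned encoders for retrieval, and the pretrained decoder for captioning
model.vision_model.load_state_dict(retrieval.vision_model.state_dict())
model.text_encoder.load_state_dict(retrieval.text_encoder.state_dict())
model.text_decoder.load_state_dict(captioning.text_decoder.state_dict())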
Thank you so much for your help @nielsr!

Hey @Vib04

I am currently trying to fine-tune a BlipForImageTextRetrieval model but am stuck.
May I ask how you fine-tuned it?

Did you use a contrastive loss?

We currently have image-text pairs in our dataset. Training the image captioning model works well, but we are a bit stuck on how to use our dataset to train BlipForImageTextRetrieval.
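To be concrete, the kind of in-batch contrastive (ITC-style) loss we had in mind looks roughly like this (just a sketch, assuming L2-normalized image and text features coming out of the projection heads):

import torch
import torch.nn.functional as F

def itc_loss(image_feats, text_feats, temperature=0.07):
    # image_feats, text_feats: (batch_size, dim), L2-normalized
    logits = image_feats @ text_feats.t() / temperature  # pairwise similarities
    targets = torch.arange(logits.size(0), device=logits.device)  # matched pairs sit on the diagonal
    # symmetric cross-entropy over image-to-text and text-to-image directions
    return (F.cross_entropy(logits, targets) + F.cross_entropy(logits.t(), targets)) / 2

Is this roughly what you did, or did you train the ITM head instead?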

Thank you!