Fine-tuning a CLIP Transformer for a downstream task

Hi everyone!

I’m currently trying to fine-tune a pre-trained CLIP model for a classification task of mine.

Right now I’m stuck trying to understand how to set this up using PyTorch.

class CLIPClassifier(nn.Module):
    """Classify an (image, text) pair from pre-trained CLIP embeddings.

    CLIP embeds the image and the text into its shared projection space.
    The two embeddings are concatenated and fed to a linear head that
    produces one logit per class.

    Args:
        num_classes: Number of output classes.
        model_name: Hugging Face checkpoint used for both the processor
            and the model.
        freeze_clip: If True, CLIP's own weights get ``requires_grad=False``
            so that only the classification head is trained.
    """

    def __init__(self, num_classes, model_name="openai/clip-vit-base-patch32",
                 freeze_clip=False):
        super().__init__()
        # Not an nn.Module; kept so callers can prepare inputs (tokenization
        # and image preprocessing) that match this exact checkpoint.
        self.processor = CLIPProcessor.from_pretrained(model_name)
        self.clip_model = CLIPModel.from_pretrained(model_name)
        if freeze_clip:
            for param in self.clip_model.parameters():
                param.requires_grad = False
        # Read the embedding width from the config instead of hard-coding 512,
        # so the head stays correct for other CLIP checkpoints.
        embed_dim = self.clip_model.config.projection_dim
        # Image and text embeddings are concatenated, hence 2 * embed_dim.
        self.classification_head = nn.Linear(2 * embed_dim, num_classes)

    def forward(self, input_ids, pixel_values, attention_mask=None):
        """Return raw class logits of shape ``(batch, num_classes)``.

        The parameter order matches ``CLIPModel.forward``, so positional
        calls that worked with the old ``*inputs`` signature map the same way.

        Args:
            input_ids: Token ids from the CLIP tokenizer/processor.
            pixel_values: Preprocessed image batch from the CLIP processor.
            attention_mask: Optional padding mask matching ``input_ids``.

        The logits are unnormalized: train with ``nn.CrossEntropyLoss`` on them
        directly (it applies log-softmax internally), and use
        ``predict_proba`` for probabilities.
        """
        # CLIPModel returns a CLIPOutput object, not a tensor; take its
        # projected per-modality embeddings rather than the whole output.
        outputs = self.clip_model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
        )
        features = torch.cat([outputs.image_embeds, outputs.text_embeds], dim=-1)
        return self.classification_head(features)

    @torch.no_grad()
    def predict_proba(self, input_ids, pixel_values, attention_mask=None):
        """Return class probabilities (softmax over ``forward``'s logits).

        Gradients are disabled here. Call ``model.eval()`` beforehand if you
        want inference-mode behavior for any dropout layers.
        """
        logits = self(input_ids, pixel_values, attention_mask)
        return torch.softmax(logits, dim=-1)

This is as far as I’ve gotten in building the CLIP model; however, it’s still missing a lot. Specifically, I don’t know how to feed the text and images into this classifier, or how to pass its output through softmax to get the probabilities. A little help or guidance would be much appreciated.

This is the dataset module I have set up:


class CustomDataset(Dataset):
    """Tweet dataset pairing each tweet's image and text with two labels.

    Each CSV row provides ``tweet_id``, ``tweet_text``, ``stance`` and
    ``persuasiveness`` columns. The image for a row is read from
    ``<image_file>/<tweet_id>.jpg``.

    Args:
        csv_file: Path to the CSV with one row per tweet.
        image_file: Directory that contains the ``<tweet_id>.jpg`` images.
        tokenizer: Optional Hugging Face-style tokenizer. When given, the
            text is returned as ``[input_ids, attention_mask]`` CPU tensors.
            Otherwise the raw string is returned.
        transforms: Optional callable applied to the RGB PIL image. Without
            it, the sample holds a PIL image, which the default DataLoader
            collate cannot batch.
    """

    # Known string labels map to ints; any other value is passed through
    # unchanged (same as the original if/elif chains).
    _STANCE_LABELS = {"oppose": 0, "support": 1}
    _PERSUASIVENESS_LABELS = {"no": 0, "yes": 1}

    def __init__(self, csv_file: str, image_file: str, tokenizer=None, transforms=None):
        self.data = pd.read_csv(csv_file)
        self.tokenizer = tokenizer
        self.image_file = image_file
        self.transforms = transforms

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return a dict with ``image``, ``tweet_text``, ``stance`` and ``persuasiveness``."""
        from pathlib import Path

        # Positional access, independent of the DataFrame's index labels.
        tweet_id = self.data["tweet_id"].iloc[index]
        tweet_text = self.data["tweet_text"].iloc[index]
        stance = self._STANCE_LABELS.get(
            self.data["stance"].iloc[index], self.data["stance"].iloc[index]
        )
        persuasiveness = self._PERSUASIVENESS_LABELS.get(
            self.data["persuasiveness"].iloc[index],
            self.data["persuasiveness"].iloc[index],
        )

        # pathlib picks the OS separator; the old hard-coded backslash
        # only worked on Windows.
        image_path = Path(self.image_file) / f"{tweet_id}.jpg"
        # Close the file handle when done. convert() loads the pixels and
        # returns a new image, so the result stays valid after the file closes.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        if self.tokenizer is not None:
            # padding=True does nothing for a single string, so samples had
            # different lengths and could not be stacked into a batch.
            # "max_length" pads every sample to the tokenizer's
            # model_max_length, which is 77 tokens for the CLIP tokenizers.
            encoding = self.tokenizer(tweet_text, padding="max_length", truncation=True)
            # Keep tensors on the CPU and move batches to the GPU in the
            # training loop. CUDA tensors created here cause problems with
            # multi-process DataLoader workers, and pin_memory only accepts
            # CPU tensors.
            tweet_text = [
                torch.tensor(encoding["input_ids"]),
                torch.tensor(encoding["attention_mask"]),
            ]

        if self.transforms is not None:
            image = self.transforms(image)

        return {
            "image": image,
            "tweet_text": tweet_text,
            "stance": stance,
            "persuasiveness": persuasiveness,
        }