Concatenating extra features in a Vision Transformer

My dataset consists of an image path, a resolution vector of size 5, and 4 labels – it is a multi-label classification problem.

I am using the google/vit-base-patch16-224-in21k pretrained Vision Transformer, and I want to concatenate the resolution to the images/features (I don't know which is more suitable). I have tried two approaches: the first does not work unless I do some hard-coding, and I am not too sure about the second, where I concatenate the resolution with the CLS token.

First approach:

“”" class ResizingModel(MultilabelImageClassificationBase):
def init(self, num_classes, model_name=‘vit_base_patch16_224’, pretrained=True):
super().init()
if model_name == ‘vit_base_patch16_224’:
self.model = vision_transformer.vit_base_patch16_224(pretrained=pretrained)
else:
raise ValueError("Invalid model name. Supported model: ‘vit_base_patch16_224’)

    fc_size = self.model.head.in_features
    self.fc = nn.Linear(fc_size + 5, num_classes)

def forward(self, x,target_resolution_one_hot):
    x = x.view(x.size(0), -1)
    x = torch.cat([x, target_resolution_one_hot], dim=1)
    x = self.fc(x)
    x = torch.sigmoid(x)
    return x '''

Error:

```
RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x151301 and 773x4)
```

I get this error when I try running the following:
```python
for images, res, labels in train_loader:
    print(images.shape)
    images = images.to(device)
    res = res.to(device)
    labels = labels.to(device)
    outputs = model(images, res)
    break

print('outputs.shape : ', outputs.shape)
print('Sample outputs :\n', outputs[:2].data)
```
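
If I read the error right, the mismatch is because `forward` never calls `self.model`: after flattening, each image contributes 151296 values (+5 for the resolution), while `self.fc` expects 768 + 5 = 773 inputs. Below is a sketch of what I think the forward should do instead – pass the image through the ViT backbone first and concatenate the resolution with the pooled feature. This is my own untested variant, assuming timm lets me swap the head for `nn.Identity()` so the model returns the 768-dim pooled feature:

```python
import torch
import torch.nn as nn
from timm.models import vision_transformer


class ResizingModelFeatures(nn.Module):  # illustrative variant, not my actual class
    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        self.model = vision_transformer.vit_base_patch16_224(pretrained=pretrained)
        fc_size = self.model.head.in_features          # 768 for ViT-Base
        self.model.head = nn.Identity()                # backbone now outputs pooled features
        self.fc = nn.Linear(fc_size + 5, num_classes)  # 768 + 5 = 773 inputs

    def forward(self, x, target_resolution_one_hot):
        feats = self.model(x)                          # (batch, 768) instead of raw pixels
        x = torch.cat([feats, target_resolution_one_hot], dim=1)  # (batch, 773)
        return torch.sigmoid(self.fc(x))
```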

Second approach:

```python
class ResizingModel(MultilabelImageClassificationBase):
    def __init__(self, num_classes):
        super(ResizingModel, self).__init__()
        self.num_classes = num_classes
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.fc = nn.Linear(self.vit.config.hidden_size + 5, num_classes)

        # print(f"Input size of self.fc: {self.vit.config.hidden_size + 5}")
        # print(f"Output size of self.fc: {num_classes}")

    def forward(self, x, target_resolution_one_hot):
        # Process the input image through the Vision Transformer
        vit_output = self.vit(x)

        # Extract the CLS token embedding from the last hidden state
        # (the last hidden state holds one embedding per token; index 0 is the CLS token)
        cls_token_embedding = vit_output.last_hidden_state[:, 0, :]  # use only the CLS token
        # print("Shape of cls_token_embedding:", cls_token_embedding.shape)
        # print("Shape of target res:", target_resolution_one_hot.shape)
        concatenated_input = torch.cat([cls_token_embedding, target_resolution_one_hot], dim=1)

        # Apply the linear classification layer; returns raw logits
        logits = self.fc(concatenated_input)
        # out = torch.sigmoid(out)
        return logits
```
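
To check the shapes of the second approach, I run it on dummy tensors (the sizes here are made up by me); since `forward` returns raw logits, I assume the matching loss is `BCEWithLogitsLoss`, which applies the sigmoid internally:

```python
import torch
from torch import nn

model = ResizingModel(num_classes=4)
images = torch.randn(2, 3, 224, 224)   # dummy batch of 2 RGB images
res = torch.randn(2, 5)                # dummy 5-dim resolution vectors

logits = model(images, res)
print(logits.shape)                    # expected: torch.Size([2, 4])

criterion = nn.BCEWithLogitsLoss()     # multi-label loss on raw logits
targets = torch.randint(0, 2, (2, 4)).float()
loss = criterion(logits, targets)
```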

This is the relevant Dataset class definition:

```python
class MultiClass(Dataset):
    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform
        self.file_names = dataframe.index
        self.CR = dataframe.CR.values.tolist()
        self.SC = dataframe.SC.values.tolist()
        self.SNS = dataframe.SNS.values.tolist()
        self.SCL = dataframe.SCL.values.tolist()
        self.res = dataframe.resolution.values.tolist()

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, index):
        image = Image.open(self.file_names[index]).convert('RGB')
        label = torch.tensor(np.array([self.CR[index], self.SC[index], self.SNS[index], self.SCL[index]]))
        res = torch.tensor(self.res[index])
        sample = image, res, label
        if self.transform:
            image = self.transform(sample[0])
            sample = image, res, label
        return sample
```
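
For completeness, this is roughly how I build the DataLoader on top of it (the transform and batch size here are illustrative; the resize matches ViT's 224×224 input):

```python
from torch.utils.data import DataLoader
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((224, 224)),     # ViT-Base/16 expects 224x224 inputs
    transforms.ToTensor(),
])

train_ds = MultiClass(train_df, transform=transform)   # train_df: my labelled dataframe
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True)
```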