My dataset consists of an image path, a resolution vector of size 5, and 4 labels — it is a multi-label classification problem.
I am using the pretrained google/vit-base-patch16-224-in21k vision transformer, and I want to concatenate the resolution vector with the images or with the extracted features (I don't know which is more suitable). I have tried two approaches: one does not work unless I do some hard-coding, and I am not sure about the second one, where I concatenate the resolution with the CLS token.
First approach:
class ResizingModel(MultilabelImageClassificationBase):
    """timm ViT backbone whose pooled features are concatenated with a
    resolution vector before a linear multi-label classification head.

    Args:
        num_classes: number of output labels.
        model_name: timm ViT variant; only 'vit_base_patch16_224' is supported.
        pretrained: whether to load pretrained backbone weights.
        res_dim: length of the resolution vector passed to ``forward``
            (default 5, the size used by this dataset).

    Raises:
        ValueError: if ``model_name`` is not 'vit_base_patch16_224'.
    """

    def __init__(self, num_classes, model_name='vit_base_patch16_224',
                 pretrained=True, res_dim=5):
        # nn.Module must be initialised before any submodule is assigned.
        super().__init__()
        if model_name != 'vit_base_patch16_224':
            raise ValueError(
                "Invalid model name. Supported model: 'vit_base_patch16_224'")
        self.model = vision_transformer.vit_base_patch16_224(pretrained=pretrained)
        # Read the feature width before discarding the ImageNet head, then
        # swap the head for Identity so self.model(x) returns one pooled
        # feature vector of size fc_size per image instead of class scores.
        fc_size = self.model.head.in_features
        self.model.head = nn.Identity()
        self.fc = nn.Linear(fc_size + res_dim, num_classes)

    def forward(self, x, target_resolution_one_hot):
        """Classify a batch of images conditioned on their resolution vectors.

        Args:
            x: image batch accepted by the ViT backbone (presumably
                (batch, 3, 224, 224) — confirm against the dataset transform).
            target_resolution_one_hot: (batch, res_dim) resolution vectors.

        Returns:
            (batch, num_classes) per-label probabilities (sigmoid applied).
        """
        # Flattening x and feeding it straight into self.fc gives an input
        # width of (flattened size + res_dim) instead of fc_size + res_dim,
        # hence the "mat1 and mat2 shapes cannot be multiplied" error.
        # Pool through the backbone first.
        features = self.model(x)
        res = target_resolution_one_hot.to(device=features.device,
                                           dtype=features.dtype)
        x = torch.cat([features, res], dim=1)
        x = self.fc(x)
        # Outputs probabilities; if the training loss already applies a
        # sigmoid (e.g. BCEWithLogitsLoss), return the logits instead.
        return torch.sigmoid(x)
Error: `RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x151301 and 773x4)`
I get this error when I try running the following:
# Smoke test: push a single batch through the model and inspect the output.
# NOTE(review): the right-hand sides of the assignments below were lost from
# the post; `.to(device)` is assumed. `model`, `train_loader` and `device`
# must be defined earlier (not shown here).
with torch.no_grad():  # shape check only; no autograd graph needed
    for images, res, labels in train_loader:
        images = images.to(device)
        res = res.to(device)
        labels = labels.to(device)
        outputs = model(images, res)
        print('outputs.shape : ', outputs.shape)
        print('Sample outputs :\n', outputs[:2])
        break  # one batch is enough to verify shapes
Second approach:
class ResizingModel(MultilabelImageClassificationBase):
    """Hugging Face ViT encoder whose [CLS] embedding is concatenated with a
    resolution vector before a linear multi-label classification head.

    Args:
        num_classes: number of output labels.
        res_dim: length of the resolution vector passed to ``forward``
            (default 5, the size used by this dataset).
    """

    def __init__(self, num_classes, res_dim=5):
        super().__init__()
        self.num_classes = num_classes
        # The checkpoint id must not contain spaces: the original
        # 'google/vit-base- patch16-224-in21k' cannot be resolved.
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224-in21k')
        self.fc = nn.Linear(self.vit.config.hidden_size + res_dim, num_classes)

    def forward(self, x, target_resolution_one_hot):
        """Return per-label logits for a batch of images and resolution vectors.

        Args:
            x: pixel values accepted by ``ViTModel`` (presumably
                (batch, 3, 224, 224) — confirm against the dataset transform).
            target_resolution_one_hot: (batch, res_dim) resolution vectors.

        Returns:
            (batch, num_classes) raw logits (no sigmoid applied).
        """
        vit_output = self.vit(pixel_values=x)
        # last_hidden_state is (batch, seq_len, hidden_size); position 0 is the
        # [CLS] token, used as a summary embedding of the whole image.
        cls_token_embedding = vit_output.last_hidden_state[:, 0, :]
        res = target_resolution_one_hot.to(device=cls_token_embedding.device,
                                           dtype=cls_token_embedding.dtype)
        concatenated_input = torch.cat([cls_token_embedding, res], dim=1)
        # Logits only: the loss must apply the sigmoid (e.g. BCEWithLogitsLoss).
        return self.fc(concatenated_input)
These are some useful class definitions:
class MultiClass(Dataset):
    """Dataset yielding ``(image, resolution, labels)`` triples from a DataFrame.

    The DataFrame index holds the image file paths; columns ``CR``, ``SC``,
    ``SNS`` and ``SCL`` hold the four labels, and ``resolution`` holds the
    resolution vector of each row.

    Args:
        dataframe: source DataFrame as described above.
        transform: optional callable applied to the loaded PIL image.
    """

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform
        self.file_names = dataframe.index
        self.CR = dataframe.CR.values.tolist()
        self.SC = dataframe.SC.values.tolist()
        self.SNS = dataframe.SNS.values.tolist()
        self.SCL = dataframe.SCL.values.tolist()
        self.res = dataframe.resolution.values.tolist()

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, index):
        # `with` closes the file handle once the RGB copy has been made.
        with Image.open(self.file_names[index]) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        # float32 targets: BCE-style losses reject integer targets and a
        # float64/float32 mismatch with the model output.
        label = torch.tensor(
            [self.CR[index], self.SC[index], self.SNS[index], self.SCL[index]],
            dtype=torch.float32)
        # float32 so it can be concatenated with float32 model features.
        res = torch.tensor(self.res[index], dtype=torch.float32)
        return image, res, label