Fine tuning image transformer on higher resolution

I would like to fine-tune a transformer on 512x512 images. I was looking at the huggingpics notebook as a starting point but failed to modify it to train on images with a size other than 224x224.
I tried adding “size” to ViTFeatureExtractor.from_pretrained (unknown keyword) and to ViTForImageClassification.from_pretrained (accepted, but trainer.fit then fails with “ValueError: Input image size (512*512) doesn’t match model (224*224).”).

While I know how to do that in plain vanilla Keras and PyTorch, Hugging Face is just too user-friendly for me!

Any help will be greatly appreciated.

Hi,

Yes, you can achieve that by passing interpolate_pos_encoding=True to the forward of the model (see the docs).
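A minimal sketch of that call, assuming a recent transformers release with PyTorch. To avoid a download it builds a tiny, randomly initialized ViT; the config sizes are arbitrary placeholders, not a real checkpoint. With a pretrained model you would load it with ViTForImageClassification.from_pretrained(...) instead and pass the flag the same way.

```python
import torch
from transformers import ViTConfig, ViTForImageClassification

# Tiny random ViT "pre-trained" at 32x32 with 8x8 patches (placeholder sizes,
# chosen only so the example runs quickly on CPU).
config = ViTConfig(
    image_size=32,
    patch_size=8,
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=64,
    num_labels=2,
)
model = ViTForImageClassification(config)
model.eval()

# A batch at twice the pre-training resolution.
pixel_values = torch.randn(2, 3, 64, 64)

with torch.no_grad():
    try:
        model(pixel_values=pixel_values)
    except ValueError as err:
        # Without the flag, the size check in the patch embedding rejects the input.
        print("without interpolation:", err)

    # With the flag, the position embeddings are resized to the new patch grid.
    outputs = model(pixel_values=pixel_values, interpolate_pos_encoding=True)

print(outputs.logits.shape)  # torch.Size([2, 2])
```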

Explanation: the Vision Transformer (ViT) was pre-trained on images of resolution 224x224. For instance, google/vit-base-patch16-224 used images of size 224x224 and a patch resolution of 16x16. This means that, when feeding an image to the model, one gets (224/16)^2 = 196 patch tokens, which are fed to the Transformer encoder. As one also adds a special CLS token at the beginning, one actually feeds 197 tokens to it.

The model adds position embeddings to each of those tokens, so that it “knows” their order. However, the embedding matrix learned during pre-training only contains 197 positions.

Suppose you’d like to fine-tune the model at a resolution of 512x512. With a patch resolution of 16x16, one has (512/16)^2 = 1024 patch tokens, and with the additional CLS token, 1025 tokens! This means we need 1025 position embeddings.
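A quick plain-Python check of these numbers. num_vit_tokens is a hypothetical helper for illustration, not a transformers function, and it assumes square images whose side is divisible by the patch size. The same formula shows that a 32x32-patch checkpoint such as google/vit-base-patch32-384 yields only (512/32)^2 + 1 = 257 tokens at 512x512.

```python
def num_vit_tokens(image_size: int, patch_size: int, add_cls: bool = True) -> int:
    """Number of tokens the ViT encoder sees for an image_size x image_size input."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    num_patches = (image_size // patch_size) ** 2
    return num_patches + (1 if add_cls else 0)


print(num_vit_tokens(224, 16))  # 197  = 196 patches + CLS
print(num_vit_tokens(512, 16))  # 1025 = 1024 patches + CLS
print(num_vit_tokens(512, 32))  # 257  = 256 patches + CLS
```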

So what usually happens is that the pre-trained position embeddings are interpolated, turning the embedding matrix from (197, 768) into (1025, 768), assuming each position embedding is a vector of size 768.
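A minimal NumPy sketch of that interpolation, for illustration only: interpolate_pos_embed is a hypothetical helper, the embeddings are random stand-ins, and it uses bilinear interpolation with corner-aligned sampling, whereas the PyTorch code in transformers uses bicubic interpolation. The CLS embedding has no spatial position, so it is copied unchanged; only the 14x14 grid of patch embeddings is resized to 32x32.

```python
import numpy as np


def interpolate_pos_embed(pos_embed: np.ndarray, new_grid: int) -> np.ndarray:
    """Resize ViT position embeddings from (1 + g*g, dim) to (1 + new_grid**2, dim).

    Row 0 (the CLS embedding) is copied unchanged; the remaining rows are
    treated as a g x g grid and resized with bilinear interpolation.
    """
    cls_embed, patch_embed = pos_embed[:1], pos_embed[1:]
    old_grid = int(round(np.sqrt(patch_embed.shape[0])))
    if old_grid * old_grid != patch_embed.shape[0]:
        raise ValueError("patch position embeddings must form a square grid")
    dim = patch_embed.shape[1]
    grid = patch_embed.reshape(old_grid, old_grid, dim)

    # Where each new row/column falls in the old grid (corners map to corners).
    coords = np.linspace(0, old_grid - 1, new_grid)
    lo = np.floor(coords).astype(int)
    hi = np.minimum(lo + 1, old_grid - 1)
    frac = coords - lo

    # Interpolate along the row axis, then along the column axis.
    rows = grid[lo] * (1 - frac)[:, None, None] + grid[hi] * frac[:, None, None]
    out = rows[:, lo] * (1 - frac)[None, :, None] + rows[:, hi] * frac[None, :, None]
    return np.concatenate([cls_embed, out.reshape(new_grid * new_grid, dim)], axis=0)


# Random stand-in for pre-trained ViT-B/16 position embeddings at 224x224: (197, 768).
pos_embed = np.random.default_rng(0).standard_normal((197, 768))
resized = interpolate_pos_embed(pos_embed, new_grid=512 // 16)
print(resized.shape)  # (1025, 768)
```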

See also section 3.2 of the paper.


Thank you so much! Now it’s working!
IMHO, as this is definitely not trivial and differs from other models, perhaps a small section on fine-tuning should be added? The same goes for feature extraction: the vectors seem to be different and much smaller (ViT gives only 768 values, compared to, say, EfficientNet-B7 with 2560 or so). The pooling mechanism is also different: the usual max and avg pooling does not work (which makes sense, as these are embeddings of patches and not actual image patches). But none of this is explained.

Thank you again for your help! Much appreciated!

Hi,

Yes definitely, let me open a PR today to add this!

And yes, ViT is very much like BERT 😉 It outputs a vector of size 768 for each “patch” (which can be seen as the equivalent of a “word”), whereas a model like ResNet outputs a “feature map” of shape (batch_size, num_channels, height, width).
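A small NumPy sketch of that shape difference, using random arrays as stand-ins for real model outputs (in transformers the per-token ViT output is last_hidden_state). Taking the CLS vector (which is what ViTForImageClassification feeds to its classifier head) or averaging the patch tokens are two common ways to get one vector per image; which works better for a given downstream task is something to check empirically.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size = 4

# Stand-in for ViT's last_hidden_state: one 768-d vector per token (CLS + 196 patches).
vit_hidden = rng.standard_normal((batch_size, 197, 768))

# Two common ways to get one vector per image from a ViT:
cls_embedding = vit_hidden[:, 0]                       # CLS token -> (batch_size, 768)
mean_patch_embedding = vit_hidden[:, 1:].mean(axis=1)  # mean of patch tokens -> (batch_size, 768)

# Stand-in for a CNN feature map with a ResNet-50-like shape at 224x224 input.
cnn_features = rng.standard_normal((batch_size, 2048, 7, 7))

# Global average pooling over height and width -> (batch_size, 2048)
cnn_embedding = cnn_features.mean(axis=(2, 3))

print(cls_embedding.shape, mean_patch_embedding.shape, cnn_embedding.shape)
# (4, 768) (4, 768) (4, 2048)
```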

Umm, I feel I am still doing something wrong: I can train with really large batches (over 128), which shouldn’t be possible even with 24 GB of GPU memory, and the results are about the same as at 224x224.
I’ll detail what I do.
This is the model I use:

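from transformers import ViTFeatureExtractor, ViTForImageClassification

def get_model(size):
    # Resize inputs to size x size (e.g. get_model(512)) instead of the checkpoint default
    feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch32-384', size=(size, size))
    model = ViTForImageClassification.from_pretrained(
        'google/vit-base-patch32-384',
        num_labels=2,
        ignore_mismatched_sizes=True,  # replace the 1000-class ImageNet head with a new 2-class head
    )
    return model, feature_extractor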

This is the training step in my PyTorch Lightning (pl) LightningModule class:

    def training_step(self, batch, batch_idx):
        # Interpolate the position embeddings so the model accepts the larger inputs
        outputs = self(**batch, interpolate_pos_encoding=True)
        self.log("train_loss", outputs.loss)
        return outputs.loss

The rest is basically the same as in the huggingpics notebook. What am I missing?

Hi,

I’m finding it hard to set interpolate_pos_encoding=True by redefining the model’s forward method.
Could you give a brief step-by-step example of how to do so in order to train the model, and not just run a single forward pass?

I thought it was just a matter of modifying model.config, which feeds the module parameters, since one can set, for example, output_attentions=True there, but I see that interpolate_pos_encoding does not work the same way.

Thank you!

I fine-tune the model with the built-in Trainer class from transformers rather than by calling the forward method directly, so I can’t find a way to set interpolate_pos_encoding to True in that case.
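One possible approach, sketched under assumptions: subclass the model and force the flag inside forward, so that whatever calls the model (Trainer calls model(**inputs)) gets interpolation without changing the training loop. ViTForImageClassificationHighRes is a hypothetical name, and the tiny random config only keeps the sketch download-free. Spelling out pixel_values and labels in the signature (rather than only **kwargs) is deliberate: as far as I know, Trainer inspects the forward signature to decide which dataset columns to keep.

```python
import torch
from transformers import ViTConfig, ViTForImageClassification


class ViTForImageClassificationHighRes(ViTForImageClassification):
    """Hypothetical subclass that always interpolates the position embeddings."""

    def forward(self, pixel_values=None, labels=None, **kwargs):
        # Force interpolation regardless of what the caller passes.
        kwargs["interpolate_pos_encoding"] = True
        return super().forward(pixel_values=pixel_values, labels=labels, **kwargs)


# Tiny random config so the sketch runs without downloads (placeholder sizes).
config = ViTConfig(image_size=32, patch_size=8, hidden_size=32, num_hidden_layers=1,
                   num_attention_heads=2, intermediate_size=64, num_labels=2)
model = ViTForImageClassificationHighRes(config)

# A batch passed as keyword arguments, the way Trainer does, at 2x resolution.
batch = {"pixel_values": torch.randn(2, 3, 64, 64), "labels": torch.tensor([0, 1])}
outputs = model(**batch)
outputs.loss.backward()  # gradients flow, so the model can be trained as usual
print(outputs.logits.shape)  # torch.Size([2, 2])
```

With a real checkpoint, the subclass would be loaded like the parent, e.g. ViTForImageClassificationHighRes.from_pretrained("google/vit-base-patch16-224", num_labels=2, ignore_mismatched_sizes=True), and then handed to Trainer unchanged.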


ValueError: mean must have 4 elements if it is an iterable, got 3

My program is showing the error above. How can I solve it?


I get the same error while trying to train an image captioning model (labels of size 32, pixel values of size [3, 224, 224]):

ValueError: mean must have 4 elements if it is an iterable, got 3

@Adhil1123 I found the solution for “ValueError: mean must have 4 elements if it is an iterable, got 3”.

My images were in CMYK mode, so they had 4 channels instead of 3. I had to convert them to RGB with img.convert('RGB') to run the code successfully.
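A minimal sketch of that fix, assuming the images are loaded with Pillow; to_rgb is a hypothetical helper name. Applying it to every image before the feature extractor also covers other non-RGB modes such as RGBA (whose alpha channel is dropped), L or P.

```python
from PIL import Image


def to_rgb(img: Image.Image) -> Image.Image:
    """Return a 3-channel RGB version of img (e.g. from CMYK, RGBA, L or P)."""
    return img if img.mode == "RGB" else img.convert("RGB")


# A 4-channel CMYK image, like the ones that triggered the error above.
cmyk = Image.new("CMYK", (8, 8))
rgb = to_rgb(cmyk)
print(cmyk.mode, len(cmyk.getbands()), "->", rgb.mode, len(rgb.getbands()))
# CMYK 4 -> RGB 3
```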

Hi, one can achieve this as explained here: Fine-tuning ViT with more patches/higher resolution - #4 by mohotmoz

This is a really good post, and it came at just the right time, when I was reconsidering ViT. Thank you!