Fine tuning image transformer on higher resolution

I would like to fine-tune a transformer on 512x512 images. I was looking at the huggingpics notebook as a starting point but failed to modify it to train on images with a size other than 224x224.
I tried passing “size” to ViTFeatureExtractor.from_pretrained (unknown keyword) and to ViTForImageClassification.from_pretrained (accepted, but it then fails at run time with “ValueError: Input image size (512*512) doesn’t match model (224*224).”).

While I know how to do that in plain vanilla Keras and PyTorch, Hugging Face is just too user friendly for me!

Any help will be greatly appreciated.


Yes you can achieve that, by passing interpolate_pos_encoding=True to the forward of the model (docs).
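A minimal sketch of that call, assuming a transformers version in which ViTForImageClassification.forward accepts interpolate_pos_encoding. To avoid a download it builds a small, randomly initialised model from a ViTConfig; the config values and num_labels=3 are illustrative assumptions, and in practice you would load your own checkpoint with from_pretrained instead.

```python
import torch
from transformers import ViTConfig, ViTForImageClassification

# Illustrative tiny config (assumption): the model is built for 224x224 inputs.
config = ViTConfig(
    image_size=224,
    patch_size=16,
    hidden_size=64,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=128,
    num_labels=3,
)
model = ViTForImageClassification(config)  # random weights, no download
model.eval()

# A batch of two 512x512 RGB images, larger than the configured 224x224.
pixel_values = torch.randn(2, 3, 512, 512)

with torch.no_grad():
    # The flag asks the model to resize its position embeddings on the fly.
    outputs = model(pixel_values=pixel_values, interpolate_pos_encoding=True)

logits_shape = tuple(outputs.logits.shape)
print(logits_shape)
```

The logits have one row per image and one column per label, regardless of the input resolution. Dropping the flag should bring back the "Input image size ... doesn't match model" error quoted in the question.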

Explanation: the Vision Transformer (ViT) was pre-trained on images of resolution 224x224. For instance, google/vit-base-patch16-224 used images of size 224x224 and a patch resolution of 16x16. This means that, when feeding an image to the model, one gets (224/16)^2 = 196 patch tokens, which are fed to the Transformer encoder. As one also adds a special CLS token at the beginning, one actually feeds 197 tokens to it.

The model adds position embeddings to each of those patch tokens, to make the model “know” the order of these tokens. However, one only has an embedding matrix containing 197 positions during pre-training.

Suppose you’d like to fine-tune the model on a resolution of 512x512. With a patch resolution of 16x16, one has (512/16)^2 = 1024 patch tokens, and with the additional CLS token, this would mean 1025 tokens! Which means we need 1025 position embeddings.
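The token counts above can be checked with a few lines of Python. This is only a sketch of the arithmetic; num_tokens is a hypothetical helper, not part of any library, and it assumes square images whose side is divisible by the patch size.

```python
def num_tokens(image_size: int, patch_size: int, cls_token: bool = True) -> int:
    """Number of tokens a ViT feeds to the encoder for a square image."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    # One token per patch, plus the optional CLS token.
    return patches_per_side ** 2 + (1 if cls_token else 0)


print(num_tokens(224, 16))  # 197
print(num_tokens(512, 16))  # 1025
```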

So what usually happens is interpolation of the pre-trained position embeddings, to basically turn the embedding matrix from (197, 768) to (1025, 768) - assuming each position embedding is a vector of size 768.
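To make the (197, 768) to (1025, 768) step concrete, here is a rough PyTorch sketch of the idea. It is not the library's exact code: resize_position_embeddings is a hypothetical helper, the random tensor stands in for the learned matrix, and bicubic interpolation is an assumption about the resampling mode. The CLS embedding is kept as-is and only the 14x14 grid of patch embeddings is resized to 32x32.

```python
import torch
import torch.nn.functional as F


def resize_position_embeddings(pos_embed: torch.Tensor, new_grid: int) -> torch.Tensor:
    """Resize (1 + old_grid**2, dim) position embeddings to (1 + new_grid**2, dim)."""
    cls_embed, patch_embed = pos_embed[:1], pos_embed[1:]
    old_grid = int(patch_embed.shape[0] ** 0.5)
    dim = patch_embed.shape[1]
    # (old_grid * old_grid, dim) -> (1, dim, old_grid, old_grid), an image-like layout
    grid = patch_embed.reshape(old_grid, old_grid, dim).permute(2, 0, 1).unsqueeze(0)
    grid = F.interpolate(grid, size=(new_grid, new_grid), mode="bicubic", align_corners=False)
    # Back to (new_grid * new_grid, dim)
    patch_embed = grid.squeeze(0).permute(1, 2, 0).reshape(new_grid * new_grid, dim)
    # Put the untouched CLS embedding back in front.
    return torch.cat([cls_embed, patch_embed], dim=0)


pretrained = torch.randn(197, 768)  # stand-in for the pre-trained matrix
resized = resize_position_embeddings(pretrained, new_grid=512 // 16)
print(tuple(resized.shape))
```

Passing interpolate_pos_encoding=True means you do not have to do this by hand; the sketch is only meant to show what kind of resizing is involved.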

See also section 3.2 of the paper.

Thank you so much! Now it's working!
IMHO, since this is definitely not trivial and differs from other models, perhaps a small section on fine-tuning should be added? The same goes for feature extraction: the vectors seem much smaller than with, say, EfficientNet-B7 (2580 or something, versus only 768 for ViT). The pooling mechanism is also different: the usual max and average pooling does not work (which makes sense, as these are embeddings of patches and not actual image patches). But all of this is unexplained.

Thank you again for your help! Much appreciated!


Yes definitely, let me open a PR today to add this!

And yes ViT is very much like BERT :wink: it outputs a vector of size 768 for each “patch” (which can be seen as each “word”), whereas a model like ResNet outputs a “feature map” of shape (batch_size, num_channels, height, width).
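As a rough illustration of that output shape, the sketch below runs a randomly initialised ViTModel and pulls out two candidate image-level features: the CLS token's vector and the mean over the patch tokens. The single-layer config is an assumption made only to keep the example fast; hidden_size=768, image_size=224 and patch_size=16 are the ViTConfig defaults. Which pooling works better for a given task is something to check empirically.

```python
import torch
from transformers import ViTConfig, ViTModel

# Single-layer, randomly initialised model (assumption, to keep the sketch fast).
model = ViTModel(ViTConfig(num_hidden_layers=1))
model.eval()

with torch.no_grad():
    hidden = model(pixel_values=torch.randn(1, 3, 224, 224)).last_hidden_state

# One 768-dim vector per token: the CLS token first, then 196 patch tokens.
cls_feature = hidden[:, 0]
mean_patch_feature = hidden[:, 1:].mean(dim=1)

print(tuple(hidden.shape))
print(tuple(cls_feature.shape), tuple(mean_patch_feature.shape))
```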

Umm, I feel I am still doing something wrong, as I can train with really large batches (over 128), which shouldn't be possible even with 24 GB of GPU memory, and the results are about the same as with 224x224.
I'll detail what I do.
This is the model I use:

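from transformers import ViTFeatureExtractor, ViTForImageClassification

def get_model(size):
    feature_extractor = ViTFeatureExtractor.from_pretrained('google/vit-base-patch32-384', size=(512, 512))
    model = ViTForImageClassification.from_pretrained(
        # The checkpoint argument was lost from the post; assumed to match the feature extractor.
        'google/vit-base-patch32-384',
        ignore_mismatched_sizes=True)
    return model, feature_extractor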

This is the training step in the PyTorch Lightning LightningModule class:

    def training_step(self, batch, batch_idx):
        outputs = self(**batch, interpolate_pos_encoding=True)
        self.log("train_loss", outputs.loss)
        return outputs.loss

The rest is basically the same as the huggingpics notebook. What am I missing?