Fine-tuning ViT with more patches/higher resolution

Hi there,
A huge thank you in advance for everyone’s help - really love this forum!!
I would like to fine-tune a ViT at a higher resolution, starting from a pretrained model trained at 384x384. Is this as simple as creating a new ViTFeatureExtractor and passing interpolate_pos_encoding=True along with pixel_values during training? It seems to me that for TRAINING at a higher resolution you would want to learn new position encodings rather than just interpolating the old ones…

I have been googling around a lot for this and wonder if anyone has a good recipe for starting from a pre-trained ViT and increasing the number of patches during fine-tuning. It seems that many of the pre-trained models are actually trained this way (224x224, then 384x384, always with 16x16 patches, so more patches and a longer sequence length). When I try this naively with HuggingFace, it just complains that the image size does not match the model's image_size : )

Thank you again.

Hi @mohotmoz,
I think you just need to:

  1. set your target size (>224) in your initial data transforms,
  2. turn on interpolate_pos_encoding in your forward pass, during both fine-tuning and evaluation. A minimal sketch of both steps follows below.
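
For illustration, here is a rough sketch of both steps (the checkpoint name, target size, and normalization stats below are just assumptions, swap in your own):

import torch
from torchvision import transforms
from transformers import ViTForImageClassification

target_size = 512  # assumption: any resolution above the pretraining size

# 1. resize (and normalize) to the new target size in your data transforms
train_transforms = transforms.Compose([
    transforms.Resize((target_size, target_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

# pretrained at 384x384 with 16x16 patches
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-384")

# 2. have the model interpolate its pretrained position embeddings
pixel_values = torch.randn(1, 3, target_size, target_size)  # stand-in batch
outputs = model(pixel_values=pixel_values, interpolate_pos_encoding=True)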

Is it possible to fine-tune ViT at a higher resolution using the Trainer class? If so, how?

Yup - with the caveat that this is just from my limited experience - there are two things to do:

  1. Adjust your feature extractor to a different resolution. It can even do the resizing for you if you want (see the sketch after the trainer snippet below).
  2. Pass interpolate_pos_encoding=True in your forward pass. The way I have done this in the past is by wrapping the standard default_data_collator in a small function:
import transformers

def my_collate(ins):
    # build the batch as usual, then flag position-embedding interpolation;
    # the Trainer unpacks every key in this dict into the model's forward
    batch = transformers.default_data_collator(ins)
    batch["interpolate_pos_encoding"] = True
    return batch

Then in the trainer (assuming you are using the HF Trainer, which is amazingly easy to use):

from transformers import Trainer

trainer = Trainer(
    data_collator=my_collate,
    # ... model, args, datasets, etc. as usual
)
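
And for step 1, a rough sketch of bumping the feature extractor's resolution (the checkpoint and size are assumptions; note that in newer transformers versions the class is ViTImageProcessor and size is a dict):

from transformers import ViTFeatureExtractor

# resize incoming images to 512x512 instead of the model's native 384x384
feature_extractor = ViTFeatureExtractor.from_pretrained(
    "google/vit-base-patch16-384",
    size=512,  # assumption: pick your own target resolution
)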

The only thing to watch out for: as your resolution increases, training memory usage goes up quickly due to the quadratic self-attention of vanilla ViT. One way around this (at the cost of time) is to enable gradient_checkpointing in your TrainingArguments. Using fsdp or deepspeed or similar tooling also helps in this regard (on multi-GPU jobs).
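
For example (output_dir and the batch size here are just placeholders):

from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="vit-highres-finetune",
    gradient_checkpointing=True,  # recompute activations to save memory
    per_device_train_batch_size=8,
)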

A huge shout-out to the amazing HF team for making everything so easy to use.
