Fine-tuning ViT with more patches/higher resolution

Yup - with the caveat that this is just from my limited experience - there are two things to do:

  1. Adjust your feature extractor (image processor) to use a different resolution. It can even do the resizing for you if you want.
  2. Pass interpolate_pos_encoding=True in your forward pass. The way I have done this in the past is by wrapping the standard default_data_collator in a small function:
```python
import transformers

def my_collate(ins):
    # build the batch as usual, then add the extra kwarg
    # that gets forwarded to the model's forward()
    batch = transformers.default_data_collator(ins)
    batch["interpolate_pos_encoding"] = True
    return batch
```
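For intuition on why the interpolation flag is needed: the pretrained position embeddings are tied to the patch count, and the patch count changes with resolution. A quick sketch (assuming the usual 16x16 patches; this is just arithmetic, not a library call):

```python
def num_patches(resolution, patch_size=16):
    # a ViT splits the image into (resolution // patch_size)^2 patches,
    # each of which gets its own position embedding
    return (resolution // patch_size) ** 2

# a checkpoint pretrained at 224px has 196 patch position embeddings;
# fine-tuning at 384px needs 576, hence the interpolation
print(num_patches(224))  # 196
print(num_patches(384))  # 576
```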

Then in the trainer (assuming you are using the HF Trainer, which is amazingly easy to use):

```python
trainer = Trainer(
    ...,
    data_collator=my_collate,
    ...
)
```

The only thing to watch out for: as your resolution increases, your training memory usage goes up due to the quadratic self-attention of vanilla ViT. One way around this (at the cost of time) is to enable gradient_checkpointing in your TrainingArguments. Using fsdp, deepspeed, or similar tooling also helps in this regard (on multi-GPU jobs).
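To put rough numbers on the quadratic growth (a back-of-the-envelope sketch, again assuming 16x16 patches; real memory use also depends on heads, layers, and the attention implementation):

```python
def attention_cost(resolution, patch_size=16):
    # the token count grows with the square of the resolution,
    # and vanilla self-attention cost with the square of the token count
    tokens = (resolution // patch_size) ** 2
    return tokens * tokens

# doubling the resolution from 224 to 448 multiplies
# the attention cost by roughly 16x
ratio = attention_cost(448) / attention_cost(224)
print(ratio)  # 16.0
```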

A huge shout-out to the amazing HF team for making everything so easy to use.
