yup - with the caveat that this is just from my limited experience - two things to do:
- Adjust your feature extractor to a different resolution. It can even do the resizing for you if you want (there's a short processor sketch below, after the collator code).
- Pass `interpolate_pos_encoding=True` in your forward pass. The way I have done this in the past is by wrapping the standard `default_data_collator` in a small function:
```python
import transformers

def my_collate(ins):
    thedict = transformers.default_data_collator(ins)
    # Ask the model to interpolate its position embeddings to the input resolution.
    thedict["interpolate_pos_encoding"] = True
    return thedict
```
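On the first point, here is a minimal sketch of overriding the target resolution on the image processor (the newer name for the feature extractor). The checkpoint name and the 384x384 size are just assumptions - swap in whatever you actually use:

```python
from transformers import ViTImageProcessor

# Assumed checkpoint and target resolution - adjust to your own setup.
# The processor will resize incoming images to 384x384 for you.
processor = ViTImageProcessor.from_pretrained(
    "google/vit-base-patch16-224",
    size={"height": 384, "width": 384},
)
```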
Then pass that collator to the trainer (assuming you are using the HF `Trainer`, which is amazingly easy to use):
```python
trainer = Trainer(
    data_collator=my_collate,
    ...
)
```
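For completeness, a fuller sketch of how the pieces might fit together - the checkpoint, the `train_ds`/`eval_ds` datasets, and the label count are all placeholders for your own setup:

```python
from transformers import Trainer, TrainingArguments, ViTForImageClassification

# Placeholder checkpoint and label count - assumes an image-classification fine-tune.
model = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    num_labels=10,
    ignore_mismatched_sizes=True,  # swap out the original 1000-class head
)

args = TrainingArguments(output_dir="vit-highres")

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,    # placeholder datasets yielding pixel_values/labels
    eval_dataset=eval_ds,
    data_collator=my_collate,  # injects interpolate_pos_encoding=True into every batch
)
trainer.train()
```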
Only thing to watch out for: as your resolution increases, memory usage during training goes up due to the quadratic self-attention of vanilla ViT. One way of getting around this (at the cost of time) is to enable `gradient_checkpointing` in your `TrainingArguments`. Using `fsdp` or `deepspeed` or similar tooling also helps in this regard (on multi-GPU jobs); a rough sketch is below.
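A rough sketch of those memory-saving knobs (the values are placeholders):

```python
from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="vit-highres",
    gradient_checkpointing=True,    # recompute activations: slower, but much less memory
    per_device_train_batch_size=4,  # smaller batches also help at higher resolutions
    # fsdp="full_shard",            # or configure FSDP / DeepSpeed for multi-GPU jobs
    # deepspeed="ds_config.json",
)
```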
A huge shout-out to the amazing HF team for making everything so easy to use.