Is it possible to train ViT with a different number of patches in every batch? (Non-square image dataset)

Hi everyone.
I have an image classification dataset consisting of non-square images, each with a different size.

When training CNNs, I used to rescale them so the longer side is 224 and zero-pad the shorter side to make them square.
Then I decided to use ViT and found that zero padding drastically hurts classification performance, since a lot of patches contain only zeros.
Random cropping and forcing a rescale to a square do not work either, because it is important to include the whole object in the image and preserve the width/height ratio.

What I want to do is feed inputs of varying sizes by rescaling each image to 224 on the longer side and X on the shorter side. I know that tensors in the same batch must have the same shape, so assume this is handled at collate time (size (BatchSize, 3, 224, 160), for example).
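For the resize step I have something roughly like this in mind (just a sketch, assuming PIL images; the resampling filter is a placeholder):

```python
from PIL import Image

def resize_longer_side(image: Image.Image, target: int = 224) -> Image.Image:
    # Scale so the longer side becomes `target`, preserving the aspect ratio,
    # e.g. 448 x 320 -> 224 x 160.
    w, h = image.size
    scale = target / max(w, h)
    return image.resize((round(w * scale), round(h * scale)), Image.BILINEAR)
```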

I loaded a ViTModel and tried feeding inputs of different sizes to see what the outputs look like. There are 49 patches + 1 CLS = 50 tokens with a 112x112 input. But when I change one dimension to 110, I lose 7 tokens.

I have no idea how the position encoding interpolation is done, or whether it is right to use the interpolate_pos_encoding=True parameter like this.
My question is: does it make sense to train it with non-square batches of different shapes? What do you suggest?

Thanks.

ViT “patchifies” images using a convolutional 2D layer with the kernel size and stride equal to model.config.patch_size.

By default, ViT uses a patch size of 16, hence (112 / 16) ** 2 = 49 patches + 1 CLS token = 50 tokens. This can be verified as follows:

from torch import nn
import torch

pixel_values = torch.randn(1, 3, 112, 112)

embedding = nn.Conv2d(in_channels=3, out_channels=768, kernel_size=16, stride=16)
print(embedding(pixel_values).shape)

This gives torch.Size([1, 768, 7, 7]), so you get a grid of 7 x 7 patches (each of dimension 768), and afterwards the special CLS token is prepended, so you end up with 50 tokens.

However, if you use an image size of 110 x 112, you’ll get a different number of patches:

pixel_values = torch.randn(1, 3, 110, 112)

embedding = nn.Conv2d(in_channels=3, out_channels=768, kernel_size=16, stride=16)
print(embedding(pixel_values).shape)

You’ll get torch.Size([1, 768, 6, 7]), so 6 x 7 = 42 patches + 1 CLS token = 43 tokens.

So ViT can be trained on variable-sized images by padding the variable-length patch sequences with a special “padding” token so that they all have the same length (similar to how this is done in NLP). This would require some custom code in modeling_vit.py, but it is doable.
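To give an idea of what that would look like (just a sketch with random tensors, not the actual modeling_vit.py code): pad each patch sequence up to the longest one in the batch and build an attention mask so the model ignores the padding tokens, exactly as in NLP:

```python
import torch

hidden_size = 768
pad_token = torch.zeros(hidden_size)  # in practice this would be a learnable nn.Parameter

# e.g. two images that yielded 50 and 43 tokens (CLS token included)
sequences = [torch.randn(50, hidden_size), torch.randn(43, hidden_size)]

max_len = max(seq.shape[0] for seq in sequences)
padded, attention_mask = [], []
for seq in sequences:
    num_pad = max_len - seq.shape[0]
    padded.append(torch.cat([seq, pad_token.repeat(num_pad, 1)], dim=0))
    attention_mask.append(torch.cat([torch.ones(seq.shape[0]), torch.zeros(num_pad)]))

batch = torch.stack(padded)                   # shape (2, 50, 768)
attention_mask = torch.stack(attention_mask)  # shape (2, 50); 1 = real token, 0 = padding
```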

The interpolate_pos_encoding flag is there to make sure that each of the patches gets an appropriate position embedding.
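As a quick sanity check (a sketch, assuming a ViT-Base checkpoint), you can already forward a non-square image through ViTModel with this flag:

```python
import torch
from transformers import ViTModel

model = ViTModel.from_pretrained("google/vit-base-patch16-224")

# a 112 x 160 image instead of the 224 x 224 the model was pre-trained on
pixel_values = torch.randn(1, 3, 112, 160)

with torch.no_grad():
    outputs = model(pixel_values, interpolate_pos_encoding=True)

# 7 x 10 = 70 patches + 1 CLS token
print(outputs.last_hidden_state.shape)  # torch.Size([1, 71, 768])
```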


Thanks Niels for the great explanation.

I wonder which is the best way to perform training on variable-sized images, although both ways are similar.
The first choice is like you described:

What I couldn’t figure out is how to stack different-sized images (tensors) into a batch just before feeding them into ViT.

And the second is padding the shorter edges so that smaller images have the same shape as the largest image in the batch. I think it could be done in the dataloader’s collate_fn. This way, padding the patch sequence or masking is not needed, but there are still some zero pads in the images.
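For reference, this is roughly what I had in mind for that collate_fn (just a sketch, assuming the dataset returns (C, H, W) tensors of varying size together with a label):

```python
import torch
import torch.nn.functional as F

def collate_fn(batch):
    images, labels = zip(*batch)
    max_h = max(img.shape[1] for img in images)
    max_w = max(img.shape[2] for img in images)
    # F.pad takes (left, right, top, bottom) for the last two dimensions
    padded = [
        F.pad(img, (0, max_w - img.shape[2], 0, max_h - img.shape[1]))
        for img in images
    ]
    return torch.stack(padded), torch.tensor(labels)
```

This would then be passed to the DataLoader via collate_fn=collate_fn.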

What do you think?

The best way as of now is probably how Pix2Struct (and newer models such as Idefics2) handle it.

They maintain the aspect ratio of images (unlike a standard ViT, which destroys the aspect ratio by resizing every image to the same square resolution). This means that each image may have a different number of patches, and they pad them all up to the same length (e.g. 2048).

This is implemented here: transformers/src/transformers/models/pix2struct/image_processing_pix2struct.py at bbaa8ceff696c479aecdb4575b2deb1349efd3aa · huggingface/transformers · GitHub. Do note that this only works for models which are trained in this way (like Pix2Struct).
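Roughly, the usage looks like this (a sketch from memory, so double-check the checkpoint name and argument names; the dummy image is just random noise):

```python
import numpy as np
from PIL import Image
from transformers import Pix2StructImageProcessor

processor = Pix2StructImageProcessor.from_pretrained("google/pix2struct-base")

# dummy non-square image; the aspect ratio is preserved during patching
image = Image.fromarray(np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8))

inputs = processor(images=image, max_patches=2048, return_tensors="pt")
print(inputs["flattened_patches"].shape)  # (1, 2048, 770): 2 position ids + 16*16*3 pixel values
print(inputs["attention_mask"].shape)     # (1, 2048): 1 for real patches, 0 for padding
```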

However, for ViT, the only option is to pass interpolate_pos_encoding=True. One could train with variable-sized images, but batches still need to be created for training, which means padding the images with zeros up to the largest one in the batch (DETR and related models also do this, as seen here).