Pyramid Vision Transformer: Issue with input image size larger than 224 px

I am working with pvt-small-224 (Pyramid Vision Transformer), which is pre-trained on images of size 3 x 224 x 224, and I want to fine-tune it on higher-resolution images (3 x 448 x 448).
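For reference, this is the kind of call I expected to be possible, based on how plain ViT checkpoints in transformers handle larger inputs via the interpolate_pos_encoding argument of ViTModel.forward. This is an untested sketch: the ViT checkpoint name and the random input are only for illustration, and PvtModel.forward does not appear to accept such an argument.

import torch
from transformers import ViTModel

vit_model = ViTModel.from_pretrained("google/vit-base-patch16-224")

dummy_input = torch.rand(1, 3, 448, 448)  # random 448 px batch, illustration only

# ViT interpolates its position embeddings to the larger grid when this flag is set
output = vit_model(dummy_input, interpolate_pos_encoding=True).last_hidden_state
print(output.shape)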

The PVT model works fine when I feed it an image of size 224 px, but it raises an error when I feed it the same image resized to 448 px.

Some prints I added to the original code (the interpolate_pos_encoding method of the PvtPatchEmbeddings class in transformers/models/pvt/modeling_pvt.py) suggest that the issue is a dimension mismatch between the patch embeddings and the interpolated positional embeddings. I have checked thoroughly, but I cannot tell whether this is a bug in the PVT code or whether I need to enable some option related to positional-embedding interpolation.
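To make explicit what I expect to happen, here is a standalone sketch of standard 2D (bicubic) interpolation of the positional embeddings. It is not the actual code from modeling_pvt.py; the shapes assume the first PVT-small stage (patch size 4, embedding dimension 64, if I read the config correctly), so a 224 px input gives a 56 x 56 grid and a 448 px input a 112 x 112 grid.

import torch
import torch.nn.functional as F

# Positional embeddings learned at 224 px for stage 1: 56 x 56 = 3136 positions, dim 64
pos_embed_224 = torch.rand(1, 56 * 56, 64)

# Patch embeddings of a 448 px input at stage 1: 112 x 112 = 12544 positions, dim 64
patch_embed_448 = torch.rand(1, 112 * 112, 64)

# Interpolate the positional embeddings on their 2D grid so the sequence lengths match
pos_2d = pos_embed_224.reshape(1, 56, 56, 64).permute(0, 3, 1, 2)      # (1, 64, 56, 56)
pos_2d = F.interpolate(pos_2d, size=(112, 112), mode="bicubic", align_corners=False)
pos_embed_448 = pos_2d.permute(0, 2, 3, 1).reshape(1, 112 * 112, 64)   # (1, 12544, 64)

print((patch_embed_448 + pos_embed_448).shape)  # torch.Size([1, 12544, 64]), no size mismatch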

Below is a short, reproducible Python script that downloads an image, resizes it to 224 px and to 448 px, and runs a PVT-small forward pass on both versions.

from PIL import Image
import requests

import torch
from torchvision import transforms

from transformers import PvtModel, AutoImageProcessor


if __name__ == "__main__":

    # Create pre-trained model

    model = PvtModel.from_pretrained("Zetatech/pvt-small-224")

    # Load image (any other image can be used)

    img_url = "https://images.pexels.com/photos/949670/pexels-photo-949670.jpeg"
    img_raw = Image.open(requests.get(img_url, stream=True).raw)
    img_original = transforms.ToTensor()(img_raw)

    ##
    ## Evaluate with image size 3 x 224 x 224
    ##

    # Create image pre-processor
    
    ctsrbm_image_transform = AutoImageProcessor.from_pretrained("Zetatech/pvt-small-224")
    ctsrbm_image_transform_corr = lambda t: torch.from_numpy(ctsrbm_image_transform(t).pixel_values[0])

    # Pre-process image

    img_preprocessed = ctsrbm_image_transform_corr(img_original)[None, :]
    print("standard resolution input shape: ", img_preprocessed.shape)

    # PVT forward pass

    output = model(img_preprocessed).last_hidden_state
    print("standard resolution output shape:", output.shape)

    ##
    ## Evaluate with image size 3 x 448 x 448
    ##

    # Create image pre-processor
    
    ctsrbm_image_transform = AutoImageProcessor.from_pretrained("Zetatech/pvt-small-224")
    ctsrbm_image_transform.size["height"] = 448
    ctsrbm_image_transform.size["width"] = 448
    ctsrbm_image_transform_corr = lambda t: torch.from_numpy(ctsrbm_image_transform(t).pixel_values[0])

    # Pre-process image

    img_preprocessed = ctsrbm_image_transform_corr(img_original)[None, :]
    print("higher resolution input shape:   ", img_preprocessed.shape)

    # PVT forward pass

    output = model(img_preprocessed).last_hidden_state
    print("higher resolution output shape:  ", output.shape)

When run, the script produces the following output:

It looks like you are trying to rescale already rescaled images. If the input images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again.
standard resolution input shape:  torch.Size([1, 3, 224, 224])
standard resolution output shape: torch.Size([1, 50, 512])
higher resolution input shape:    torch.Size([1, 3, 448, 448])
Traceback (most recent call last):
  File "/home-net/gortega/fashion_retrieval/nn_backbone_pvt-small-224_explore copy.py", line 59, in <module>
    output = model(img_preprocessed).last_hidden_state
  File "/home-net/gortega/fashion_retrieval/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home-net/gortega/fashion_retrieval/.venv/lib/python3.8/site-packages/transformers/models/pvt/modeling_pvt.py", line 569, in forward
    encoder_outputs = self.encoder(
  File "/home-net/gortega/fashion_retrieval/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home-net/gortega/fashion_retrieval/.venv/lib/python3.8/site-packages/transformers/models/pvt/modeling_pvt.py", line 436, in forward
    hidden_states, height, width = embedding_layer(hidden_states)
  File "/home-net/gortega/fashion_retrieval/.venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home-net/gortega/fashion_retrieval/.venv/lib/python3.8/site-packages/transformers/models/pvt/modeling_pvt.py", line 156, in forward
    embeddings = self.dropout(embeddings + position_embeddings)
RuntimeError: The size of tensor a (64) must match the size of tensor b (16) at non-singleton dimension 2

Below is the output of transformers-cli env as well:

- `transformers` version: 4.33.1
- Platform: Linux-5.15.0-71-generic-x86_64-with-glibc2.29
- Python version: 3.8.10
- Huggingface_hub version: 0.17.1
- Safetensors version: 0.3.3
- Accelerate version: not installed
- Accelerate config: not found
- PyTorch version (GPU?): 2.0.1+cu117 (True)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: NO
- Using distributed or parallel set-up in script?: NO