Extract visual and contextual features from images

Yes! The modeling file you're referring to is actually just a Vision Transformer with a modified head, as it explicitly states:

"ViTSTR is basically a ViT that uses DeiT weights.
Modified head to support a sequence of characters prediction for STR."

So if you want to create a similar model using HuggingFace Transformers, you can, as ViT is available (documentation can be found here). You just need to define a similar classification head, as follows:

import torch.nn as nn
from transformers import ViTModel

class ViTSTR(nn.Module):
    def __init__(self, config, num_labels):
        super(ViTSTR, self).__init__()
        self.vit = ViTModel(config)
        self.head = nn.Linear(config.hidden_size, num_labels) if num_labels > 0 else nn.Identity()
        self.num_labels = num_labels

    def forward(self, pixel_values, seqlen=25):
        outputs = self.vit(pixel_values=pixel_values)
        # only keep the first seqlen token positions of the last hidden state
        x = outputs.last_hidden_state[:, :seqlen]

        # batch_size, seqlen, embedding size
        b, s, e = x.size()
        x = x.reshape(b*s, e)
        x = self.head(x).view(b, s, self.num_labels)
        return x

You can then initialize the model as follows:

from transformers import ViTConfig

config = ViTConfig()
model = ViTSTR(config, num_labels=10)
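
For a quick sanity check (using a hypothetical random batch, and assuming the default 224x224 resolution of ViTConfig), the model maps pixel values to a tensor of shape (batch_size, seqlen, num_labels):

import torch

# dummy batch of 2 RGB images at the default ViT resolution (for illustration only)
pixel_values = torch.randn(2, 3, 224, 224)
logits = model(pixel_values, seqlen=25)
print(logits.shape)  # torch.Size([2, 25, 10])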

Note that this doesn’t use any transfer learning. You can of course load any pre-trained ViTModel from the hub by replacing self.vit = ViTModel(config) in the code above with, for example, self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224").
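
As a sketch of that transfer-learning variant (the checkpoint name is just an example, any ViT checkpoint on the hub should work), you could let the wrapper load the backbone itself and read the hidden size from the loaded model's config:

import torch.nn as nn
from transformers import ViTModel

class PretrainedViTSTR(nn.Module):
    def __init__(self, num_labels, checkpoint="google/vit-base-patch16-224"):
        super().__init__()
        # load a pre-trained ViT backbone from the hub
        self.vit = ViTModel.from_pretrained(checkpoint)
        # the head's input size comes from the loaded model's config
        hidden_size = self.vit.config.hidden_size
        self.head = nn.Linear(hidden_size, num_labels) if num_labels > 0 else nn.Identity()
        self.num_labels = num_labels

    def forward(self, pixel_values, seqlen=25):
        outputs = self.vit(pixel_values=pixel_values)
        # only keep the first seqlen token positions of the last hidden state
        x = outputs.last_hidden_state[:, :seqlen]
        b, s, e = x.size()
        x = self.head(x.reshape(b * s, e)).view(b, s, self.num_labels)
        return x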
