Extract visual and contextual features from images



I’m currently trying to build an OCR (text recognition) model that extracts visual and contextual features with a Transformer architecture. My current architecture looks like this:
Visual features: DeiT FeatureExtractor (could it be that this is just a preprocessing step for the DeiT model, and no forward pass / model is called there?)
Contextual features: BiLSTM / stacked BiLSTM
Prediction: Attention

Is it possible to use the Transformers library to get the visual and contextual features from a pre-trained model? Or what would you recommend? The idea behind this is to move away from the original CRNN/LSTM architecture and extract the features with a Transformer instead, to speed up inference and get faster results than with Tesseract.
(This model will later be used mainly for document images with a lot of text.)
Or do you perhaps even have an example of how to achieve something like this?

One example from the synthetic dataset:

The labels are plain text.

PS: I mainly use PyTorch. The text detection part is already done, so I only need the pure recognition part.

Many greetings
Felix

Edit: Does Transformers have a ready-to-use implementation of something like the ViTSTR paper, or do you have any recommendations on how I could rebuild it with pretrained models from the Transformers library instead of timm?


The feature extractors (like ViTFeatureExtractor and DeiTFeatureExtractor) prepare images for Transformer-based models (ViT and DeiT respectively); they don't run a model themselves. They mainly do two things: resize images to a given size and normalize the channels. After the feature extractor (called with return_tensors="pt"), an image is turned into a PyTorch tensor of shape (batch_size, num_channels, height, width), for example (1, 3, 224, 224). Next, this tensor is provided to a Transformer that turns it into contextual features. For prediction, one typically places a linear classification head (nn.Linear) on top of the contextual features.
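A minimal sketch of the two steps described above, written in plain PyTorch rather than with the feature extractor itself (which resizes via PIL): resize to 224×224, normalize each channel, then put an `nn.Linear` head on top of per-token features. The mean/std of 0.5 mirror the ViTFeatureExtractor call later in this thread; the random "contextual features" tensor is only a stand-in for a Transformer's `last_hidden_state`, and the 768 hidden size and 10 labels are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def preprocess(images: torch.Tensor, size: int = 224,
               mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)) -> torch.Tensor:
    """Resize a uint8 batch (B, C, H, W) to (B, C, size, size) and normalize per channel."""
    x = images.float() / 255.0  # scale pixel values to [0, 1]
    # Bilinear resize (an approximation of what the feature extractor does with PIL).
    x = F.interpolate(x, size=(size, size), mode="bilinear", align_corners=False)
    mean_t = torch.tensor(mean).view(1, -1, 1, 1)
    std_t = torch.tensor(std).view(1, -1, 1, 1)
    return (x - mean_t) / std_t  # with mean=std=0.5 this maps [0, 1] to [-1, 1]

# A fake text-line image: batch of 1, RGB, 32 px high, 128 px wide.
images = torch.randint(0, 256, (1, 3, 32, 128)).to(torch.uint8)
pixel_values = preprocess(images)
print(pixel_values.shape)  # torch.Size([1, 3, 224, 224])

# Stand-in for the Transformer's contextual features: (batch, num_tokens, hidden_size).
hidden_size, num_labels = 768, 10  # illustrative sizes
contextual_features = torch.randn(1, 197, hidden_size)
head = nn.Linear(hidden_size, num_labels)
logits = head(contextual_features)  # nn.Linear acts on the last dimension
print(logits.shape)  # torch.Size([1, 197, 10])
```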

You might be interested in this project: GitHub - him4318/Transformer-ocr: Handwritten text recognition using transformers. It’s based on DETR, which is available in HuggingFace Transformers. Note that DETR itself consists of a convolutional backbone + encoder-decoder Transformer.

Instead of using classification heads for predicting class labels + bounding boxes (as was done in the original DETR, which was meant for object detection), he simply adds a linear layer on top of the Transformer outputs, which acts as a “language modeling decoder” (similar to what is done in models like BERT during pre-training). This language modeling decoder maps the contextual features of the Transformer to actual words. This language modeling decoder is defined here.
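To make the “language modeling decoder” idea concrete, here is a small sketch: a linear layer maps each position's contextual feature vector to vocabulary logits, and greedy decoding takes the argmax at every position. The toy character vocabulary, the hidden size and the random features are assumptions for illustration only; this is not the code from the linked project.

```python
from typing import List

import torch
import torch.nn as nn

# Hypothetical character vocabulary; index 0 is reserved as an end-of-sequence marker.
vocab = ["<eos>"] + list("abcdefghijklmnopqrstuvwxyz0123456789 ")
hidden_size = 256  # illustrative

lm_decoder = nn.Linear(hidden_size, len(vocab))  # contextual feature -> vocab logits

def greedy_decode(contextual_features: torch.Tensor) -> List[str]:
    """contextual_features: (batch, seq_len, hidden_size) -> one string per batch item."""
    logits = lm_decoder(contextual_features)  # (batch, seq_len, vocab_size)
    token_ids = logits.argmax(dim=-1)  # most likely token at each position
    texts = []
    for ids in token_ids.tolist():
        chars = []
        for i in ids:
            if i == 0:  # stop at the end-of-sequence marker
                break
            chars.append(vocab[i])
        texts.append("".join(chars))
    return texts

features = torch.randn(2, 25, hidden_size)  # stand-in for Transformer outputs
print(greedy_decode(features))  # untrained weights, so the strings are random
```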


Hi Niels,

Thanks for your answer, I will check this out. But do you have any recommendation on how to rebuild this with the Transformers library, without the timm model:



Yes! The modeling file that you refer to is actually just a Vision Transformer but with a modified head, as it explicitly mentions:

"ViTSTR is basically a ViT that uses DeiT weights.
Modified head to support a sequence of characters prediction for STR."

So if you want to create a similar model using HuggingFace Transformers, you can, as ViT is available (documentation can be found here). We just need to define a similar classification head, as follows:

import torch.nn as nn
from transformers import ViTModel

class ViTSTR(nn.Module):
    def __init__(self, config, num_labels):
        super(ViTSTR, self).__init__()
        self.vit = ViTModel(config)
        self.head = nn.Linear(config.hidden_size, num_labels) if num_labels > 0 else nn.Identity()
        self.num_labels = num_labels

    def forward(self, pixel_values, seqlen=25):
        outputs = self.vit(pixel_values=pixel_values)
        # keep only the first seqlen tokens of the final-layer hidden states
        x = outputs.last_hidden_state[:, :seqlen]

        # batch_size, seqlen, embedding size
        b, s, e = x.size()
        x = x.reshape(b*s, e)
        x = self.head(x).view(b, s, self.num_labels)
        return x
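A side note on the reshape in `forward`: `nn.Linear` already operates on the last dimension of a tensor of any rank, so calling `self.head(x)` directly on the `(batch_size, seqlen, hidden_size)` tensor gives the same result as the reshape/view round trip. The quick check below uses random tensors with arbitrarily chosen sizes.

```python
import torch
import torch.nn as nn

b, s, e, num_labels = 2, 25, 768, 10  # arbitrary sizes
x = torch.randn(b, s, e)  # stand-in for outputs.last_hidden_state[:, :seqlen]
head = nn.Linear(e, num_labels)

via_reshape = head(x.reshape(b * s, e)).view(b, s, num_labels)  # as in ViTSTR.forward
direct = head(x)  # Linear is applied over all leading dimensions

print(torch.allclose(via_reshape, direct, atol=1e-5))  # True
```

Both forms are fine; the direct call simply avoids manual bookkeeping of `b`, `s` and `e`.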

You can then initialize the model as follows:

from transformers import ViTConfig

config = ViTConfig()
model = ViTSTR(config, num_labels=10)

Note that this doesn’t use any transfer learning. You can of course load any pre-trained ViTModel from the hub, by replacing self.vit = ViTModel(config) in the code above with self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224") for example.
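Because `forward` slices `last_hidden_state[:, :seqlen]`, `seqlen` must not exceed the number of tokens the ViT produces: one token per image patch plus the [CLS] token. The helper below is a hypothetical name doing plain arithmetic; the `extra_tokens=2` case is meant for models that also add a distillation token (such as the distilled DeiT variants in DeiTModel).

```python
def vit_num_tokens(image_size: int = 224, patch_size: int = 16, extra_tokens: int = 1) -> int:
    """Number of tokens in a ViT's last_hidden_state: one per patch plus special tokens.

    extra_tokens=1 accounts for the [CLS] token; use 2 if a distillation token is added too.
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2 + extra_tokens

print(vit_num_tokens())  # 197 for ViTConfig() defaults (224 px images, 16 px patches)
seqlen = 25
assert seqlen <= vit_num_tokens(), "seqlen is larger than the number of ViT tokens"
```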

Hi Niels,

Thanks for your posts, this was very helpful!

The model training now runs, but I have the feeling that the model does not learn anything.
I have run a few tests with ~1000 examples from my generated data (~50 epochs), but the loss and accuracy stay stuck at the same values, with only minimal fluctuations (up and down). Any ideas?

Input Example after FeatureExtractor:

Pretrained models used:

# pretrained_part
    if pretrained_model == 'deit_tiny':
        pretrained_model = 'facebook/deit-tiny-patch16-224'
    elif pretrained_model == 'deit_small':
        pretrained_model = 'facebook/deit-small-patch16-224'
    elif pretrained_model == 'deit_tiny_distilled':
        pretrained_model = 'facebook/deit-tiny-distilled-patch16-224'
    elif pretrained_model == 'deit_small_distilled':
        pretrained_model = 'facebook/deit-small-distilled-patch16-224'
    else:
        pretrained_model = None  # use ViTModel from config

    feature_extractor = ViTFeatureExtractor(do_resize=True, size=224, do_normalize=True, image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5])
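The `if`/`elif` chain above can also be written as a lookup table, so that unknown names fall back to `None` (build the model from a config) explicitly. A minimal sketch; `resolve_checkpoint` and `DEIT_CHECKPOINTS` are hypothetical names, and the checkpoint IDs are the ones listed above.

```python
from typing import Optional

# Short names -> Hugging Face Hub checkpoint IDs (taken from the snippet above).
DEIT_CHECKPOINTS = {
    "deit_tiny": "facebook/deit-tiny-patch16-224",
    "deit_small": "facebook/deit-small-patch16-224",
    "deit_tiny_distilled": "facebook/deit-tiny-distilled-patch16-224",
    "deit_small_distilled": "facebook/deit-small-distilled-patch16-224",
}

def resolve_checkpoint(name: Optional[str]) -> Optional[str]:
    """Return the Hub ID for a known short name, or None to build ViTModel from a config."""
    return DEIT_CHECKPOINTS.get(name) if name else None

print(resolve_checkpoint("deit_small_distilled"))  # facebook/deit-small-distilled-patch16-224
print(resolve_checkpoint("from_scratch"))  # None
```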


class ViTSTR(nn.Module):
    def __init__(self, pretrained_model: str, num_labels: int):
        super(ViTSTR, self).__init__()
        self.pretrained_model = pretrained_model
        if self.pretrained_model:
            self.vit = ViTModel.from_pretrained(pretrained_model)
        else:
            config = ViTConfig(hidden_size=768, num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0, initializer_range=0.02, layer_norm_eps=1e-12, is_encoder_decoder=False, image_size=224, patch_size=16, num_channels=3)
            self.vit = ViTModel(config)
        self.num_labels = num_labels # len of numbers and symbols
        self.head = nn.Linear(self.vit.config.hidden_size, self.num_labels) if self.num_labels > 0 else nn.Identity()

    def forward(self, pixel_values, seqlen):
        outputs = self.vit(pixel_values=pixel_values)
        # keep only the first seqlen tokens of the final-layer hidden states
        hidden_states = outputs.last_hidden_state[:, :seqlen]

        # batch_size, seqlen, embedding size
        b, s, e = hidden_states.size()
        hidden_states = hidden_states.reshape(b*s, e)
        vocab = self.head(hidden_states).view(b, s, self.num_labels)
        return vocab

Model output:
torch.Size([2, 34, 110]), i.e. [batch size, max sequence length 32 + 2 ([s] and [GO] tokens), vocabulary size 108 + 2 ([s] and [GO] tokens)]
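Judging from the encoded targets shown in `training_step` below, each label appears to be encoded as [GO] (index 0, also used as padding) + characters + [s] (index 1), with characters starting at index 2. The sketch below reproduces that layout under this assumption; `build_vocab` and `encode` are hypothetical helpers, not the converter actually used here, and the 13-character set is made up.

```python
from typing import Dict, List

GO, EOS = "[GO]", "[s]"  # assumed special tokens; index 0 doubles as padding

def build_vocab(characters: str) -> Dict[str, int]:
    """Map special tokens and characters to indices: [GO]=0, [s]=1, characters from 2."""
    tokens = [GO, EOS] + list(characters)
    return {tok: i for i, tok in enumerate(tokens)}

def encode(texts: List[str], vocab: Dict[str, int], max_chars: int) -> List[List[int]]:
    """Encode each text as [GO] + characters + [s], padded with [GO] to max_chars + 2."""
    length = max_chars + 2
    batch = []
    for text in texts:
        if len(text) > max_chars:
            raise ValueError(f"label longer than {max_chars} characters: {text!r}")
        ids = [vocab[GO]] + [vocab[c] for c in text] + [vocab[EOS]]
        batch.append(ids + [vocab[GO]] * (length - len(ids)))
    return batch

vocab = build_vocab("0123456789abc")  # hypothetical 13-character set -> 15 classes
encoded = encode(["42", "a1"], vocab, max_chars=32)
print(len(encoded[0]))  # 34 positions = 32 characters + [GO] + [s]
print(encoded[0][:5])  # [0, 6, 4, 1, 0]
```

With 108 characters and `max_chars=32`, the same scheme gives the 110 classes and 34 positions seen in the model output.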

Loss (CrossEntropy):

torch.Size([2, 34, 110])
tensor(4.8557, grad_fn=<NllLossBackward>)
torch.Size([2, 34, 110])
tensor(5.1618, grad_fn=<NllLossBackward>)
torch.Size([2, 34, 110])
tensor(5.4171, grad_fn=<NllLossBackward>)
torch.Size([2, 34, 110])
tensor(5.0631, grad_fn=<NllLossBackward>)
torch.Size([2, 34, 110])
tensor(5.4249, grad_fn=<NllLossBackward>)
torch.Size([2, 34, 110])
tensor(5.3074, grad_fn=<NllLossBackward>)
torch.Size([2, 34, 110])
tensor(5.0706, grad_fn=<NllLossBackward>)
torch.Size([2, 34, 110])
tensor(4.9988, grad_fn=<NllLossBackward>)
torch.Size([2, 34, 110])
tensor(5.0464, grad_fn=<NllLossBackward>)

(On 2 GPUs I have also tested a bigger dataset with batch size 96 for 50 epochs.)
But there, too, the loss is around ~5.xx after 1 epoch and does not decrease up to epoch 50.
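As a reference point for these numbers: if the model outputs identical logits for every class, cross-entropy over 110 classes equals ln(110) ≈ 4.70, so a loss hovering around 5 is roughly at chance level. The check below reproduces that with all-zero logits in the same [2, 34, 110] layout, flattened the same way as in `training_step`; the random targets are made up.

```python
import math

import torch
import torch.nn as nn

batch_size, seq_len, num_classes = 2, 34, 110
logits = torch.zeros(batch_size, seq_len, num_classes)  # uniform prediction over all classes
target = torch.randint(0, num_classes, (batch_size, seq_len))  # made-up class indices

loss_fn = nn.CrossEntropyLoss()
# Flatten to (N, C) logits and (N,) targets, with N = 2 * 34 = 68.
loss = loss_fn(logits.view(-1, num_classes), target.view(-1))

print(loss.item(), math.log(num_classes))  # both approx. 4.7005
```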

Some code:

def forward(self, pixel_values, labels=None, seqlen=32):
        prediction = self.vit_model(pixel_values, seqlen)
        return prediction
def training_step(self, batch, batch_idx):
        image_tensors = batch['pixel_values']
        labels = batch['text']
        target = self.converter.encode(labels).to(device=self.device)
        # target shape: [2, 34]
        # tensor([[ 0, 51, 56, 47,  6,  3, 11, 11,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        #           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        #         [ 0,  9, 47,  9,  8,  2, 47,  2,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,
        #           0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]])
        predictions = self(image_tensors, labels=target, seqlen=self.converter.batch_max_length)
        # predictions shape: [2, 34, 110]
        # predictions.view(-1, predictions.shape[-1]): [68, 110]  (68 = 2 x 34)
        # target.contiguous().view(-1): [68]

        loss = self.loss_fn(predictions.view(-1, predictions.shape[-1]), target.contiguous().view(-1))
        return {"loss": loss, "train_acc": accuracy(predictions.permute(0, 2, 1), target).detach()}
def validation_step(self, batch, batch_idx):
        # TODO: DEBUG !!! Check if it works correctly
        image_tensors = batch['pixel_values']
        labels = batch['text']
        batch_size = image_tensors.size(0)
        self.length_of_data = self.length_of_data + batch_size
        target = self.converter.encode(labels).to(device=self.device)

        predictions = self(image_tensors, labels=target, seqlen=self.converter.batch_max_length)

        _, preds_index = predictions.topk(1, dim=-1, largest=True, sorted=True)
        preds_index = preds_index.view(-1, self.converter.batch_max_length)

        loss = self.loss_fn(predictions.contiguous().view(-1, predictions.shape[-1]), target.contiguous().view(-1))

        length_for_pred = torch.IntTensor([self.converter.batch_max_length - 1] * batch_size)
        preds_str = self.converter.decode(preds_index[:, 1:], length_for_pred)

        # compute word error rate
        word_error_rate = wer(predictions=preds_str, references=labels)

        preds_prob = F.softmax(predictions, dim=2)
        preds_max_prob, _ = preds_prob.max(dim=2)

        norm_ED = 0
        for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
            pred_EOS = pred.find('[s]')
            pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
            pred_max_prob = pred_max_prob[:pred_EOS]

            confidence = torch.sum(pred_max_prob) / len(pred_max_prob)

            # debug
            #print(f'ground_truth: {gt:25}  prediction: {pred:25s}  \nconfidence: {confidence}\n')

            if len(gt) == 0 or len(pred) == 0:
                norm_ED += 0
            elif len(gt) > len(pred):
                norm_ED += 1 - edit_distance(pred, gt) / len(gt)
            else:
                norm_ED += 1 - edit_distance(pred, gt) / len(pred)

        norm_ED = float(norm_ED / float(batch_size))  # ICDAR2019 Normalized Edit Distance, averaged over this batch
        print(f'normalized edit distance is: {norm_ED}')

        return {"loss": loss, "val_acc": accuracy(predictions.permute(0, 2, 1), target).detach(), "word_error_rate": torch.tensor(word_error_rate).detach()}
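The validation step relies on an `edit_distance` helper and the ICDAR2019-style normalized edit distance. A self-contained sketch of both, assuming the standard Levenshtein distance and the per-sample rule used above (1 - distance / length of the longer string, 0 if either string is empty), averaged over the batch; `levenshtein` and `normalized_edit_score` are hypothetical names.

```python
from typing import List

def levenshtein(a: str, b: str) -> int:
    """Minimum number of single-character insertions, deletions and substitutions."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(prev[j] + 1,                # delete ca
                            curr[j - 1] + 1,            # insert cb
                            prev[j - 1] + (ca != cb)))  # substitute (free if equal)
        prev = curr
    return prev[-1]

def normalized_edit_score(preds: List[str], gts: List[str]) -> float:
    """Average of 1 - ED / max(len(gt), len(pred)) over the batch; empty strings score 0."""
    total = 0.0
    for pred, gt in zip(preds, gts):
        if len(gt) == 0 or len(pred) == 0:
            continue
        total += 1 - levenshtein(pred, gt) / max(len(gt), len(pred))
    return total / len(gts) if gts else 0.0

print(levenshtein("kitten", "sitting"))  # 3
print(normalized_edit_score(["hello", "wrld"], ["hello", "world"]))  # 0.9
```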

Thanks again for your help, and let me know if you need more information 🙂
PS: I have added you to the repo if you want to take a look:
→ Task_Experimential → Lightning OCR → transformer_based
Best regards
