Extract visual and contextual features from images

@nielsr

Hi Niels,

thanks for your posts, they were very helpful :slightly_smiling_face:

Now the model training runs, but I have the feeling that the model is not learning anything.
I have run a few tests with ~1000 examples from my generated data (~50 epochs), but the loss and accuracy stay stuck at practically the same values, with only minimal fluctuation up and down. Any ideas?

Input example after the FeatureExtractor: (image attachment)

Pretrained model used:

# pretrained_part
    if pretrained_model == 'deit_tiny':
        pretrained_model = 'facebook/deit-tiny-patch16-224'
    elif pretrained_model == 'deit_small':
        pretrained_model = 'facebook/deit-small-patch16-224'
    elif pretrained_model == 'deit_tiny_distilled':
        pretrained_model = 'facebook/deit-tiny-distilled-patch16-224'
    elif pretrained_model == 'deit_small_distilled':
        pretrained_model = 'facebook/deit-small-distilled-patch16-224'
    else:
        pretrained_model = None  # use ViTModel from config

    feature_extractor = ViTFeatureExtractor(
        do_resize=True,
        size=224,
        do_normalize=True,
        image_mean=[0.5, 0.5, 0.5],
        image_std=[0.5, 0.5, 0.5],
    )
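
For context, this is roughly how I get the pixel_values from an image (a minimal sketch; the real loading happens in my dataset class, and the image path here is just a placeholder):

from PIL import Image
from transformers import ViTFeatureExtractor

feature_extractor = ViTFeatureExtractor(
    do_resize=True, size=224, do_normalize=True,
    image_mean=[0.5, 0.5, 0.5], image_std=[0.5, 0.5, 0.5]
)

# "word_crop.png" is just a placeholder path, not a file from my dataset
image = Image.open("word_crop.png").convert("RGB")
encoding = feature_extractor(images=image, return_tensors="pt")
pixel_values = encoding["pixel_values"]  # shape: [1, 3, 224, 224]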

Model:

class ViTSTR(nn.Module):
    def __init__(self, pretrained_model: str, num_labels: int):
        super(ViTSTR, self).__init__()
        self.pretrained_model = pretrained_model
        if self.pretrained_model:
            self.vit = ViTModel.from_pretrained(pretrained_model)
        else:
            config = ViTConfig(
                hidden_size=768,
                num_hidden_layers=12,
                num_attention_heads=12,
                intermediate_size=3072,
                hidden_act="gelu",
                hidden_dropout_prob=0.0,
                attention_probs_dropout_prob=0.0,
                initializer_range=0.02,
                layer_norm_eps=1e-12,
                is_encoder_decoder=False,
                image_size=224,
                patch_size=16,
                num_channels=3,
            )
            self.vit = ViTModel(config)
        self.num_labels = num_labels  # vocabulary size (characters/symbols incl. the special tokens)
        self.head = nn.Linear(self.vit.config.hidden_size, self.num_labels) if self.num_labels > 0 else nn.Identity()

    def forward(self, pixel_values, seqlen):
        outputs = self.vit(pixel_values=pixel_values)
        # keep only the first seqlen hidden states from the last hidden state
        hidden_states = outputs.last_hidden_state[:, :seqlen]

        # batch_size, seqlen, embedding size
        b, s, e = hidden_states.size()
        hidden_states = hidden_states.reshape(b*s, e)
        vocab = self.head(hidden_states).view(b, s, self.num_labels)
        return vocab

Model Output:
torch.Size([2, 34, 110]) → [batch size, max seq length 32 + 2 ([GO] and [s] tokens), vocab size 108 + 2 ([GO] and [s] tokens)]
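
To double check that shape, something like this minimal sketch (random pixel values instead of real images, reusing the ViTSTR class from above) reproduces it:

import torch

# random inputs just to check shapes, not real data
model = ViTSTR(pretrained_model='facebook/deit-tiny-patch16-224', num_labels=110)
dummy_pixel_values = torch.randn(2, 3, 224, 224)
out = model(dummy_pixel_values, seqlen=34)
print(out.shape)  # torch.Size([2, 34, 110])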

Loss (CrossEntropy):

torch.Size([2, 34, 110])
tensor(4.8557, grad_fn=<NllLossBackward>)
torch.Size([2, 34, 110])
tensor(5.1618, grad_fn=<NllLossBackward>)
torch.Size([2, 34, 110])
tensor(5.4171, grad_fn=<NllLossBackward>)
torch.Size([2, 34, 110])
tensor(5.0631, grad_fn=<NllLossBackward>)
torch.Size([2, 34, 110])
tensor(5.4249, grad_fn=<NllLossBackward>)
torch.Size([2, 34, 110])
tensor(5.3074, grad_fn=<NllLossBackward>)
torch.Size([2, 34, 110])
tensor(5.0706, grad_fn=<NllLossBackward>)
torch.Size([2, 34, 110])
tensor(4.9988, grad_fn=<NllLossBackward>)
torch.Size([2, 34, 110])
tensor(5.0464, grad_fn=<NllLossBackward>)

(On 2x GPUs I have also tested a bigger dataset with batch size 96 for 50 epochs,
but there as well the loss sits around ~5.xx after the first epoch and does not decrease up to epoch 50.)
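
For reference: with a vocab size of 110, a completely uniform prediction corresponds to a cross entropy of ln(110) ≈ 4.70, which is not far from where my loss is stuck:

import math

# cross entropy of a uniform prediction over 110 classes
print(math.log(110))  # ≈ 4.70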

Some code:

def forward(self, pixel_values, labels=None, seqlen=32):
        prediction = self.vit_model(pixel_values, seqlen)
        return prediction
def training_step(self, batch, batch_idx):
        image_tensors = batch['pixel_values']
        labels = batch['text']
        target = self.converter.encode(labels).to(device=self.device)
        """
        size: [2, 34]
        tensor([[ 0, 51, 56, 47,  6,  3, 11, 11,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
        [ 0,  9, 47,  9,  8,  2, 47,  2,  1,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0]])
        """
        predictions = self(image_tensors, labels=target, seqlen=self.converter.batch_max_length)
        """
        size: [2, 34, 110]
        predictions.view(-1, predictions.shape[-1]) : [68, 110]   -> 2x34
        target.contiguous().view(-1) : [68, 110] -> 2x34
        """

        loss = self.loss_fn(predictions.view(-1, predictions.shape[-1]), target.contiguous().view(-1))
        return {"loss": loss, "train_acc": accuracy(predictions.permute(0, 2, 1), target).detach()}
def validation_step(self, batch, batch_idx):
        # TODO: DEBUG !!! Check if this works correctly
        image_tensors = batch['pixel_values']
        labels = batch['text']
        batch_size = image_tensors.size(0)
        self.length_of_data = self.length_of_data + batch_size
        target = self.converter.encode(labels).to(device=self.device)

        predictions = self(image_tensors, labels=target, seqlen=self.converter.batch_max_length)

        _, preds_index = predictions.topk(1, dim=-1, largest=True, sorted=True)
        preds_index = preds_index.view(-1, self.converter.batch_max_length)

        loss = self.loss_fn(predictions.contiguous().view(-1, predictions.shape[-1]), target.contiguous().view(-1))

        length_for_pred = torch.IntTensor([self.converter.batch_max_length - 1] * batch_size)
        preds_str = self.converter.decode(preds_index[:, 1:], length_for_pred)

        # compute word error rate
        word_error_rate = wer(predictions=preds_str, references=labels)

        preds_prob = F.softmax(predictions, dim=2)
        preds_max_prob, _ = preds_prob.max(dim=2)

        norm_ED = 0
        for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):
            pred_EOS = pred.find('[s]')
            pred = pred[:pred_EOS]  # prune after "end of sentence" token ([s])
            pred_max_prob = pred_max_prob[:pred_EOS]

            confidence = torch.sum(pred_max_prob) / len(pred_max_prob)

            # debug
            #print(f'ground_truth: {gt:25}  prediction: {pred:25s}  \nconfidence: {confidence}\n')

            if len(gt) == 0 or len(pred) == 0:
                norm_ED += 0
            elif len(gt) > len(pred):
                norm_ED += 1 - edit_distance(pred, gt) / len(gt)
            else:
                norm_ED += 1 - edit_distance(pred, gt) / len(pred)


        norm_ED = float(norm_ED / float(self.length_of_data)) # ICDAR2019 Normalized Edit Distance
        print(f'normalized edit distance is: {norm_ED}')

        return {"loss": loss, "val_acc": accuracy(predictions.permute(0, 2, 1), target).detach(), "word_error_rate": torch.tensor(word_error_rate).detach()}

Thanks again for your help, and let me know if you need more information :slight_smile:
PS: I have added you to the repo if you want to take a look:
Repository
→ Task_Experimential → Lightning OCR → transformer_based
Best regards
Felix