Training issue with the Transformer CAPTCHA recognition model: Unable to converge

I have built a model from scratch, inspired by the Transformer model and related code (such as ViT), with the goal of recognizing CAPTCHAs. However, during training, I’ve encountered an issue with the Transformer model. After several batch iterations, I consistently observe that the highest probability value in the output probability matrix is <EOS>, and this problem persists even after prolonged training.

Here is an overview of my approach: I initially followed the ViT approach, dividing each input image into many small patches and linearly projecting each patch to a fixed dimension emb_d. This gives the input sequence for the encoder. For the decoder, I embed the CAPTCHA characters into the same emb_d dimension (note: the vocabulary consists of digits and letters, [0-9a-zA-Z]).
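
A minimal sketch of the patch-embedding step, assuming a single-channel 30x90 image, 3x3 patches, and emb_d = 36. These sizes and the helper name `patchify` are illustrative assumptions, not the notebook's exact code:

```python
import torch
from torch import nn


def patchify(img: torch.Tensor, patch: int) -> torch.Tensor:
    """Split [B, C, H, W] images into flattened, non-overlapping patches.

    Returns a tensor of shape [B, (H // patch) * (W // patch), patch * patch * C].
    """
    b, c, h, w = img.shape
    assert h % patch == 0 and w % patch == 0, "image size must be divisible by patch size"
    x = img.reshape(b, c, h // patch, patch, w // patch, patch)
    x = x.permute(0, 2, 4, 3, 5, 1)  # [B, H/p, W/p, p, p, C]
    return x.reshape(b, (h // patch) * (w // patch), patch * patch * c)


emb_d = 36                             # assumed embedding size
img = torch.randn(3, 1, 30, 90)        # assumed batch of grayscale images
patches = patchify(img, patch=3)       # [3, 300, 9]
to_emb = nn.Linear(patches.shape[-1], emb_d)
tokens = to_emb(patches)               # [3, 300, 36], the encoder input sequence
```

Each row of `tokens` corresponds to one image patch, so the sequence length seen by the encoder is the number of patches, not the image width or height.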

For the encoder, I use the image patches as input and pass them through multiple encoder blocks, each consisting of multi-head self-attention layers, layer normalization, residual connections, and linear layers. Finally, the encoder’s output matches the input’s shape, i.e., [batch len_batch emb_d], and this output serves as both the key and value matrices for the decoder.
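
To illustrate how the encoder output is used as both key and value, here is a small sketch built on PyTorch's `nn.MultiheadAttention`. The sizes are assumptions, and the notebook uses its own attention implementation rather than this module:

```python
import torch
from torch import nn

emb_d, heads = 36, 3  # assumed sizes; emb_d must be divisible by heads
cross_attn = nn.MultiheadAttention(emb_d, heads, batch_first=True)

enc_out = torch.randn(2, 300, emb_d)   # encoder output: [batch, image tokens, emb_d]
dec_in = torch.randn(2, 5, emb_d)      # decoder states: [batch, text tokens, emb_d]

# Queries come from the decoder; keys and values both come from the encoder output.
attended, weights = cross_attn(query=dec_in, key=enc_out, value=enc_out)
```

`attended` keeps the decoder's sequence length, while `weights` holds, for every text position, a distribution over the image tokens.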

For the decoder, I use the target sequence (with a shape of [batch len_batch emb_d] and the last token removed) as input and set the target sequence (with the first token removed) as the actual target. I then compute the cross-entropy loss between the output and the target.
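
The input/target split and the loss computation can be sketched as follows. The token ids are hypothetical (0 is assumed to be <BOS>; 1 is <EOS> as in the screenshots), and random logits stand in for the decoder output:

```python
import torch
from torch import nn

vocab = 12
# Hypothetical token ids: 0 = <BOS> (assumed), 1 = <EOS>, 2.. = characters.
tgt = torch.tensor([[0, 5, 7, 3, 1],
                    [0, 4, 4, 9, 1]])

dec_in = tgt[:, :-1]   # <BOS> c1 c2 c3   (last token removed)
labels = tgt[:, 1:]    # c1 c2 c3 <EOS>   (first token removed)

# Stand-in for the decoder output: raw logits of shape [batch, seq, vocab].
logits = torch.randn(dec_in.shape[0], dec_in.shape[1], vocab)

# CrossEntropyLoss expects [N, vocab] logits and [N] class indices,
# so the batch and sequence dimensions are flattened together.
loss = nn.CrossEntropyLoss()(logits.reshape(-1, vocab), labels.reshape(-1))
```

Note that `nn.CrossEntropyLoss` applies log-softmax internally, so it should receive raw logits rather than probabilities.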

[image: screenshot of the argmax output (out) and the target indices (tgt)]

The issue I’ve identified is as follows: taking the argmax of the output probability matrix should yield the index of the predicted label (out), which ideally matches the target index (tgt). As the screenshots show, however, ‘out’ is consistently index 1, which corresponds to “<EOS>”.

You can find the code, in case you want to check the structure for errors, at the following location:

Jupyter Notebook Viewer

I have roughly verified the network structure and found no errors, but I remain uncertain. I hope someone can help me analyze this issue, and I would be extremely grateful for any assistance in resolving it.

I’m not an expert, but I think you should not put an activation layer directly in front of a softmax.

So this is where I suspect the problem is:

    self.dense_layer = nn.Linear(emb_d, vocab) 
    self.out_h = nn.GELU()
    self.softmax_layer = nn.LogSoftmax(dim=-1)

Try it without that activation layer, or add another linear layer in between.
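
A small sketch of why the activation matters here, with assumed sizes. GELU is bounded below at roughly -0.17, so a GELU placed right before the log-softmax prevents the head from giving any class a strongly negative score. The sketch also shows that Linear followed by LogSoftmax and `nn.NLLLoss` matches `nn.CrossEntropyLoss` on the raw logits:

```python
import torch
from torch import nn

torch.manual_seed(0)
emb_d, vocab = 36, 12                      # assumed sizes
hidden = torch.randn(4, emb_d)             # stand-in for decoder output
labels = torch.randint(0, vocab, (4,))

dense = nn.Linear(emb_d, vocab)
logits = dense(hidden)

# GELU squashes the negative side: its output never goes far below zero,
# so the head can no longer assign a strongly negative score to a class.
gelu_min = nn.GELU()(torch.linspace(-10, 10, 10001)).min()

# Without the activation: Linear -> LogSoftmax -> NLLLoss ...
loss_a = nn.NLLLoss()(nn.LogSoftmax(dim=-1)(logits), labels)
# ... which matches CrossEntropyLoss applied to the raw logits.
loss_b = nn.CrossEntropyLoss()(logits, labels)
```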

First of all I must congratulate you for putting everything together and making it work. Kudos.

I’m a newbie, but here are my two cents:

Your loss seems way too high compared to models available on Hugging Face.

Supposing everything else is right (I didn’t go through the code), the only issue I could see from your description is this:

You said you take the tokenized target sequence, remove the last token to form the decoder input, and remove the first token to form the labels.

But AFAIK, the decoder input IDs are the tokenized target labels (no tokens removed), shifted to the right by the decoder BOS token. The decoder is then conditioned on the embeddings generated by the encoder to sequentially predict the next token (with masked attention, of course).
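
For comparison, here is a small sketch of both schemes in plain Python, under the assumption that the tokenized target in the question already contains both <BOS> and <EOS> and that the labels in the second scheme end in <EOS>. The token names and the example text are hypothetical:

```python
BOS, EOS = "<BOS>", "<EOS>"

chars = list("a3Xk")                 # hypothetical CAPTCHA text
full_target = [BOS] + chars + [EOS]  # tokenized target with both special tokens

# Scheme from the question: drop the last token for the input,
# drop the first token for the labels.
dec_in_a = full_target[:-1]
labels_a = full_target[1:]

# Scheme from this reply: labels are the characters plus <EOS>,
# and the decoder input is those labels shifted right by <BOS>.
labels_b = chars + [EOS]
dec_in_b = [BOS] + labels_b[:-1]

print(dec_in_a == dec_in_b, labels_a == labels_b)  # prints: True True
```

Under those assumptions the two schemes produce the same pairs, so they differ only if the original target sequence lacks one of the special tokens.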

Hmm… Although this problem was solved a few days ago, I still want to share my experience:

  1. Check your network structure: in the example I provided, the token input sequence of the decoder was missing positional encoding (although the network still seemed able to fit without it). After adding it, the network fits easily.
  2. The expressive power of a network is not easy to estimate for a novice like me, and I don’t have a device like an RTX 4090 to search quickly. My suggestion is therefore to move from simple networks to complex ones and from small data to large data. That way, structural problems show up early as a model that fits neither the training set nor the test set. After a long time of parameter tuning, a simple network with 1 layer, 3 heads, and an MLP of 72 dimensions reached 98% accuracy on 30*90 images. I used to think a much more complex network was needed.
  3. Use Ray Tune combined with cloud computing platforms such as Kaggle for hyperparameter search (remember this, otherwise you will waste a lot of time tuning hyperparameters by hand).

Thank you for your suggestions.
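
The “start small” advice in point 2 can be turned into a quick sanity check: before scaling up, confirm that the model can memorize one tiny batch. The sketch below uses a hypothetical toy classifier and random data rather than the CAPTCHA model, purely to show the shape of the check:

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical tiny problem: 8 samples, 16 features, 4 classes.
x = torch.randn(8, 16)
y = torch.randint(0, 4, (8,))

model = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 4))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.CrossEntropyLoss()

losses = []
for step in range(300):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

accuracy = (model(x).argmax(dim=-1) == y).float().mean().item()
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}, train accuracy {accuracy:.2f}")
```

If the loss on such a batch does not fall close to zero, the problem is more likely in the model structure or the data pipeline than in the hyperparameters.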

Additionally, this is the final correct code for the Transformer architecture:

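import torch
from einops import rearrange  # used in forward(); may also arrive via the star imports below
from matplotlib import pyplot as plt
from torch import nn
from torch.nn import TransformerEncoder
from torchvision.transforms import transforms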
from Trocr.Trans.Pos_Emb import *
from Trocr.Trans.Decoder import *
from Trocr.Trans.Encoder import *
from rich import console

console = console.Console()


class Transformer(nn.Module):
    def __init__(self, block_size, mlp_d, head_num, emb_d, dropout, in_channels, img_x, img_y, patch_d, vocab,
                 max_length):
        super(Transformer, self).__init__()
        self.max_length = max_length
        self.img_w = img_x
        self.img_h = img_y
        self.patch_d = patch_d
        self.in_channels = in_channels
        self.token_d = in_channels * (patch_d ** 2)
        self.token_num = (img_y // patch_d) * (img_x // patch_d)

        # Learnable positional embeddings for the image patches and the target tokens;
        # their values are set by the Xavier initialization below.
        self.pos_emb = nn.Parameter(torch.empty(self.token_num, emb_d, device=device))
        self.pos_emb2 = nn.Parameter(torch.empty(max_length, emb_d, device=device))

        nn.init.xavier_normal_(self.pos_emb)
        nn.init.xavier_normal_(self.pos_emb2)
        # self.pos_emb = PositionalEncoding(d_model=emb_d, dropout=dropout, max_len=self.token_num)
        # self.cls_emb = nn.Parameter(torch.randn(1, 1, emb_d, device=device))
        # nn.init.xavier_normal_(self.pos_emb)
        # self.transformer_encoder = nn.TransformerEncoder(num_layers=block_size, norm=True)
        self.encoder = Encoder(block_size, mlp_d, head_num, emb_d, dropout)
        self.decoder = Decoder(block_size, mlp_d, head_num, emb_d, dropout)

        self.emb_layer = nn.Linear(self.token_d, emb_d)
        self.embbding_layper = nn.Embedding(vocab, emb_d)

        self.mask = nn.Transformer.generate_square_subsequent_mask(max_length, device)
        self.drop_layer = nn.Dropout(dropout)

        self.dense_layer = nn.Linear(emb_d, vocab)  # self.out_h = nn.LeakyReLU()

    def forward(self, words, img=None, enc_out=None, **kwargs):
        if img is not None:
            # console.log(img.shape)
            # B C H W
            img = rearrange(img, 'b c (patch_x x) (patch_y y) -> b (x y) (patch_x patch_y c)',
                            patch_x=self.patch_d,
                            patch_y=self.patch_d)
            # imm = transforms.ToPILImage()
            # plt.imshow(imm(img[0]))
            # plt.show()
            img = self.emb_layer(img)
            # x: [b t emb_d]
            batch, tokens, _ = img.shape
            # classification token + positional encoding initialization
            # cls = repeat(self.cls_emb, 'b ...-> (b batches) ...', batches=batch)
            # img = torch.cat([cls, img], dim=1)
            # console.log(img.shape)
            # console.log(self.pos_emb.shape)
            img += self.pos_emb
            # img = self.pos_emb(img)
            img = self.drop_layer(img)

            enc_out = self.encoder(img)
        # enc_out = self.transformer_encoder(img)
        # console.log(enc_out[0][-1])
        emb = self.embbding_layper(words)
        emb += self.pos_emb2[:len(emb[0])]

        # mask = torch.triu(torch.ones(len(words[0]), len(words[0])), diagonal=1).bool().cuda()

        # console.log(mask)
        # Crop the causal mask to the current target length.
        seq_len = words.shape[1]
        next_word = self.decoder(x=emb, enc_out=enc_out, mask=self.mask[:seq_len, :seq_len])
        # console.log(next_word[0][-1])
        # console.log(next_word.shape)
        out = self.dense_layer(next_word)

        # console.log(out[0])
        # console.log(torch.argmax(out[0], dim=-1))

        return out, enc_out


if __name__ == '__main__':
    vit = Transformer(block_size=2, mlp_d=36, head_num=2, emb_d=36, dropout=0.1, in_channels=1, img_x=30,
                      img_y=90, patch_d=3, vocab=12, max_length=10).to(device)
    console.log(summary(vit,
                        input_data=(torch.as_tensor(torch.randint(0, 10, (3, 5), device=device)).long(),
                                    torch.randn([3, 1, 30, 90], device=device)
                                    )))

    # pred = vit(torch.randn([3, 3, 90, 30], device=device), torch.randint(0, 64, (3, 1000), device=device))
    # tgt = torch.randint(0, 64, (3, 1000), device=device)
    # pred = rearrange(pred, 'b t d->b d t')
    # console.log(pred.shape)
    # console.log(tgt.shape)
    # loss = nn.CrossEntropyLoss()
    # print(loss(pred, tgt))
    #
    # x = vit(torch.randn([3, 3, 100, 100], device=device))
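
Because forward() returns enc_out and also accepts it as an argument, inference can encode the image once and then decode token by token. The following is only a sketch of greedy decoding under assumptions: `greedy_decode` is a hypothetical helper, index 0 is assumed to be <BOS> (index 1 is <EOS> as in the question), and `DummyModel` is a stand-in with the same call signature so that the example runs on its own:

```python
import torch
from torch import nn


class DummyModel(nn.Module):
    """Stand-in with the same call signature as the Transformer above."""

    def __init__(self, vocab=12, eos_idx=1, stop_after=4):
        super().__init__()
        self.vocab, self.eos_idx, self.stop_after = vocab, eos_idx, stop_after

    def forward(self, words, img=None, enc_out=None):
        if img is not None:
            enc_out = img.flatten(1)  # pretend this is the encoder output
        batch, length = words.shape
        out = torch.zeros(batch, length, self.vocab)
        # Predict token 5 until `stop_after` tokens are present, then <EOS>.
        nxt = self.eos_idx if length >= self.stop_after else 5
        out[:, -1, nxt] = 1.0
        return out, enc_out


@torch.no_grad()
def greedy_decode(model, img, bos_idx=0, eos_idx=1, max_length=10):
    """Decode a batch by repeatedly appending the argmax token."""
    model.eval()
    words = torch.full((img.shape[0], 1), bos_idx, dtype=torch.long, device=img.device)
    enc_out = None
    for _ in range(max_length - 1):
        # Encode the image only on the first step, then reuse enc_out.
        out, enc_out = model(words, img=img if enc_out is None else None, enc_out=enc_out)
        next_tok = out[:, -1].argmax(dim=-1, keepdim=True)
        words = torch.cat([words, next_tok], dim=1)
        if (next_tok == eos_idx).all():  # stop once every sequence emitted <EOS>
            break
    return words


decoded = greedy_decode(DummyModel(), torch.randn(1, 1, 30, 90))
```

With the real model, the loop bound keeps the sequence within max_length, which matters because both the causal mask and pos_emb2 are sized by max_length.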