How is padding masking considered in the Attention Head of a Transformer?

For purely educational purposes, my goal is to implement a basic Transformer architecture from scratch. So far I have focused on the encoder for classification tasks and assumed that all samples in a batch have the same length. This means I didn’t have to care about any masking.

However, now I want to support masking. I like to think that I understand the purpose of, e.g., the target mask, so the model cannot “peek into the future”. I generate this mask as follows:

import torch

source_batch = torch.LongTensor([
    [1, 2, 3, 0, 0, 0],
    [1, 2, 3, 4, 5, 6],
    [1, 2, 3, 4, 5, 0]
])

batch_size, seq_len = source_batch.shape

def generate_tgt_mask(size):
    # Upper triangular part (strictly above the diagonal) is -inf, the rest is 0
    return torch.triu(torch.ones(size, size) * float('-inf'), diagonal=1)

print(generate_tgt_mask(seq_len))

yielding:

tensor([[0., -inf, -inf, -inf, -inf, -inf],
        [0.,   0., -inf, -inf, -inf, -inf],
        [0.,   0.,   0., -inf, -inf, -inf],
        [0.,   0.,   0.,   0., -inf, -inf],
        [0.,   0.,   0.,   0.,   0., -inf],
        [0.,   0.,   0.,   0.,   0.,   0.]])

which is the expected outcome according to the PyTorch docs. This mask has a shape of (L, L), where L is the sequence length of the source or target sequence. Again, this matches the docs.
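
Purely as a sanity check for myself, the same additive mask can also be built from a boolean matrix of “future” positions via masked_fill (the variable names below are just mine):

# Boolean matrix marking the positions strictly above the diagonal (the "future")
future_positions = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
# Fill those positions with -inf and leave the rest at 0; same tensor as printed above
additive_tgt_mask = torch.zeros(seq_len, seq_len).masked_fill(future_positions, float('-inf'))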

I use this mask in my implementation of the Scaled Dot Product Attention as follows – which should be in line with many other implementations I’ve seen:

import torch
import torch.nn as nn
import torch.nn.functional as f

class Attention(nn.Module):
    ### Implements Scaled Dot Product Attention
    
    def __init__(self):
        super().__init__()


    def forward(self, Q, K, V, mask=None, dropout=None):
        # All shapes: (batch_size, seq_len, hidden_size)
        
        # Perform Q*K^T (* is the dot product here)
        # We have to use torch.matmul since we work with batches!
        out = torch.matmul(Q, K.transpose(1, 2)) # => shape: (B, L, L)

        # Divide by scaling factor
        out = out / (Q.shape[-1] ** 0.5)

        # Optional: src_mask/tgt_mask (shape: (L, L); mask values are represented by -inf)
        if mask is not None:
            out += mask.unsqueeze(0) # Broadcast since it's the same mask for all samples in batch
        
        # Push through softmax layer to turn the scores into attention weights
        out = f.softmax(out, dim=-1)
        
        # Optional: Dropout on the attention weights (dropout is a probability)
        if dropout is not None:
            out = f.dropout(out, p=dropout)
        
        # Multiply with values V
        out = torch.matmul(out, V)
        
        return out
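
For completeness, this is how I currently call the module (dummy tensors only; the hidden size of 8 is arbitrary and just for this test):

hidden_size = 8  # arbitrary, just for testing
attention = Attention()

# Random queries/keys/values with shape (batch_size, seq_len, hidden_size)
Q = torch.rand(batch_size, seq_len, hidden_size)
K = torch.rand(batch_size, seq_len, hidden_size)
V = torch.rand(batch_size, seq_len, hidden_size)

out = attention(Q, K, V, mask=generate_tgt_mask(seq_len))
print(out.shape)  # torch.Size([3, 6, 8])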

So far, so good… at least I like to think so. However, my problem now is the mask that addresses the padding (e.g. src_key_padding_mask). Judging from different tutorials using nn.Transformer, this mask can be generated as follows:

pad_token_index = 0

src_key_padding_mask = (source_batch != pad_token_index)

print(src_key_padding_mask)

yielding:

tensor([[ True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True, False]])

having a shape of (N, L), which again matches the docs.
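
As a quick sanity check (just for myself), summing over this mask gives the number of real tokens per sequence:

# Number of non-padding tokens per sequence
print(src_key_padding_mask.sum(dim=1))  # tensor([3, 6, 5])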

What I’m now missing is: how do I incorporate this matrix into my implementation of Attention?

Intuitively, I would assume that the masking matrix should contain -inf at every position associated with a padding token. For example, looking at the first sequence in my example batch above, I would expect the masking matrix to look like this:

tensor([[0.,   0.,   0.,   -inf, -inf, -inf],
        [0.,   0.,   0.,   -inf, -inf, -inf],
        [0.,   0.,   0.,   -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf],
        [-inf, -inf, -inf, -inf, -inf, -inf]])

And indeed, some – but not all – example code that implements the Transformer architecture from scratch creates the padding mask exactly like this. Applying this matrix to the scores masks not only the last 3 columns but also the last 3 rows entirely, since every entry in those rows gets masked out.
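
To make this concrete, here is a rough sketch of how I would build that per-sample (L, L) matrix from src_key_padding_mask (the helper name and the outer-AND construction are just my own attempt, not taken from any particular tutorial):

def expand_padding_mask(key_padding_mask):
    # key_padding_mask: (N, L) with True for real tokens, False for padding.
    # Position (i, j) is kept only if both token i and token j are real tokens.
    keep = key_padding_mask.unsqueeze(1) & key_padding_mask.unsqueeze(2)  # (N, L, L)
    return torch.zeros(keep.shape).masked_fill(~keep, float('-inf'))

# The first sample reproduces the matrix shown above
print(expand_padding_mask(src_key_padding_mask)[0])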

However, once pushed through softmax, these last 3 rows now all contain the value 1/6. For example, for the source_batch above I get:

tensor([[[0.1989, 0.4297, 0.3714, 0.0000, 0.0000, 0.0000],
         [0.4334, 0.2225, 0.3440, 0.0000, 0.0000, 0.0000],
         [0.2880, 0.2284, 0.4836, 0.0000, 0.0000, 0.0000],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667],
         [0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667]],
       ...
       (the other 2 samples of the batch are not shown)
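
Just to pin down where the 1/6 comes from: the softmax of a row whose scores are all equal is simply the uniform distribution. A quick check (using a large negative fill value; a literal -inf across the whole row would even produce NaNs):

# A fully masked row of scores: every position holds the same, very negative value
fully_masked_row = torch.full((6,), -1e9)
print(torch.softmax(fully_masked_row, dim=-1))
# tensor([0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667])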

What am I missing here? I’m pretty sure it’s something trivial, but I just can’t see it right now.