Hi.
Backstory: I tried to visualize some static BERT embeddings, taken before the first transformer block, and was wondering whether I should average them per input. But then what about inputs of different lengths? I suspected that the embeddings of the padding token would be zero, so I could simply average over all positions, padding included.
The Problem: For a plain-vanilla PyTorch embedding this does seem to be the case, but for BERT it is not: the padding embeddings are not zero. The layer starts out as an ordinary PyTorch Embedding with padding_idx=0, yet the pretrained padding row ends up with non-zero values. Examples below.
I can live with that, as I found a way to work around it, but I was wondering whether there is something (subtle) to be learned here. Why is the padding embedding not zero? How could backprop ever reach embeddings that should not be accessible, being masked out, anyway?
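(For reference, the workaround I mean is essentially a masked average that uses the attention mask instead of relying on zero padding rows; a minimal sketch with made-up names, not my exact code:)

import torch

def masked_mean(embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # embeddings: (batch, seq_len, hidden); attention_mask: (batch, seq_len), 1 = real token, 0 = padding
    mask = attention_mask.unsqueeze(-1).type_as(embeddings)  # (batch, seq_len, 1)
    summed = (embeddings * mask).sum(dim=1)                  # padding positions contribute nothing
    counts = mask.sum(dim=1).clamp(min=1)                    # number of real tokens per sequence
    return summed / counts                                   # mean over non-padding positions only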
Plain-vanilla PyTorch:
emb = torch.nn.Embedding(5, 3, padding_idx=0)  # row 0 is the padding row: initialized to zeros, excluded from gradient updates
inp = torch.tensor([1, 2, 3, 0, 0, 0, 0])
emb(inp)
=>
tensor([[-0.9628,  0.4631, -0.1923],
        [ 0.7668,  0.0380, -1.1776],
        [ 0.0938,  0.9070,  0.5080],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000]], grad_fn=<EmbeddingBackward0>)
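(In this plain setting the padding row also never receives a gradient, which is part of why I expected it to stay zero; a quick check on the same emb and inp:)

out = emb(inp)
out.sum().backward()
print(emb.weight.grad[0])  # tensor([0., 0., 0.]) -- the gradient at padding_idx is always zero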
BERT:
from transformers import AutoModelForSequenceClassification

bert = AutoModelForSequenceClassification.from_pretrained('bert-base-uncased')
bert.bert.embeddings.word_embeddings  # Embedding(30522, 768, padding_idx=0)
inp = torch.tensor([ 101, 2339, 2079, 2111, 2224, 1012, 4372, 2615, 5371, 2006,
                    8241, 1029, 2339, 2079, 2111, 2404, 1037, 1012, 4372, 2615,
                    5371, 2000, 3573, 2035, 2037, 7800, 1999, 1037, 8241, 1029,
                    2065, 2619, 20578, 2015, 2009, 1025, 3475, 1005, 1056, 1996,
                    1012, 4372, 2615, 8053, 7801, 2004, 2035, 1996, 2060, 6764,
                    1029, 4283, 999, 102, 0, 0, 0, 0, 0, 0,
                    0, 0, 0, 0], dtype=torch.long)
bert.bert.embeddings.word_embeddings(inp)
=>
tensor([[ 0.0138, -0.0260, -0.0237,  ...,  0.0090,  0.0067,  0.0144],
        [ 0.0214,  0.0150, -0.0675,  ...,  0.0120,  0.0240,  0.0203],
        [ 0.0066, -0.0536,  0.0063,  ..., -0.0139, -0.0531, -0.0086],
        ...,
        [-0.0102, -0.0614, -0.0264,  ..., -0.0198, -0.0371, -0.0097],
        [-0.0102, -0.0614, -0.0264,  ..., -0.0198, -0.0371, -0.0097],
        [-0.0102, -0.0614, -0.0264,  ..., -0.0198, -0.0371, -0.0097]],
       grad_fn=<EmbeddingBackward0>)
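(The identical rows at the end are all lookups of row 0 of the weight matrix, i.e. the [PAD] embedding, and checking that row directly confirms it is genuinely non-zero in the pretrained checkpoint:)

pad_row = bert.bert.embeddings.word_embeddings.weight[0]
print(pad_row.abs().max())                                 # clearly non-zero
print(torch.allclose(pad_row, torch.zeros_like(pad_row)))  # False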