Hey there!
I am currently trying to understand how some of the transformer models work, and I am starting by focusing on BERT. Because I am still figuring things out, I would highly appreciate corrections to any assumptions I state in this post!
Mapping the output embeddings back to the initial tokens is of special interest to me - a task which is done by the MLM head:
class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size,
                                 config.vocab_size,
                                 bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states) + self.bias
        return hidden_states
So the hidden states passed to the head are first transformed by the BertPredictionHeadTransform class and then fed to a linear layer with an “external” bias. The transform simply applies a linear transformation that keeps the input shape, followed by an activation function and layer normalization:
class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super(BertPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states
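To check the shapes for myself, here is a minimal sketch of the whole head using plain PyTorch modules as stand-ins. The toy sizes, the GELU activation and the LayerNorm epsilon are my own assumptions for illustration, not values read from a real BERT config:

```python
import torch
import torch.nn as nn

# Toy sizes, chosen only for illustration (not BERT's real configuration).
hidden_size, vocab_size = 8, 20
batch_size, seq_len = 2, 5

# Stand-in for BertPredictionHeadTransform: dense -> activation -> LayerNorm.
transform = nn.Sequential(
    nn.Linear(hidden_size, hidden_size),
    nn.GELU(),                             # assumption: hidden_act is "gelu"
    nn.LayerNorm(hidden_size, eps=1e-12),  # assumption: layer_norm_eps is 1e-12
)

# Stand-in for the decoder with its separate, output-only bias.
decoder = nn.Linear(hidden_size, vocab_size, bias=False)
bias = nn.Parameter(torch.zeros(vocab_size))

hidden_states = torch.randn(batch_size, seq_len, hidden_size)
transformed = transform(hidden_states)  # shape is preserved
logits = decoder(transformed) + bias    # last dimension becomes vocab_size

print(transformed.shape)  # torch.Size([2, 5, 8])
print(logits.shape)       # torch.Size([2, 5, 20])
```

So the transform keeps the (batch, sequence, hidden) shape, and only the decoder maps the last dimension from the hidden size to the vocabulary size.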
So far so good (unless I got something wrong!) - what I don’t understand yet is how the weights for the decoder are determined. According to the comment above the decoder definition, the input embeddings are used as the weights of the linear layer. This seems to be enforced by the following code:
def tie_weights(self):
    """ Make sure we are sharing the input and output embeddings.
        Export to TorchScript can't handle parameter sharing so we are cloning them instead.
    """
    self._tie_or_clone_weights(self.cls.predictions.decoder,
                               self.bert.embeddings.word_embeddings)
This works because PyTorch's nn.Linear stores its weight with shape (out_features, in_features) and applies it transposed, so the decoder weight has the same shape as the embedding matrix, correct? But is the cloning of the weights just some sort of initialization, so that they are still trained further (together with the bias) during the MLM pretraining task?
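To make the difference between sharing and cloning concrete, here is a small sketch with stand-in modules. It assumes, as the docstring suggests, that "tying" means both modules hold the same Parameter object and that cloning is only the TorchScript fallback; I have not verified this against the body of _tie_or_clone_weights, and the sizes are made up:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
hidden_size, vocab_size = 8, 20  # toy sizes for illustration only

word_embeddings = nn.Embedding(vocab_size, hidden_size)
decoder = nn.Linear(hidden_size, vocab_size, bias=False)

# Both weights have shape (vocab_size, hidden_size): nn.Linear stores
# (out_features, in_features) and computes x @ weight.T.
shapes_match = word_embeddings.weight.shape == decoder.weight.shape

# Sharing (assumed non-TorchScript path): one Parameter used by both modules.
decoder.weight = word_embeddings.weight
is_shared = decoder.weight is word_embeddings.weight

# One SGD step on a dummy loss that only goes through the decoder.
optimizer = torch.optim.SGD(decoder.parameters(), lr=0.1)
before = word_embeddings.weight.detach().clone()
loss = decoder(torch.randn(3, hidden_size)).sum()
loss.backward()
optimizer.step()
embedding_changed = not torch.equal(before, word_embeddings.weight.detach())

# Cloning (assumed TorchScript path): a separate Parameter with equal values,
# which is free to drift away from the embeddings during training.
cloned_decoder = nn.Linear(hidden_size, vocab_size, bias=False)
cloned_decoder.weight = nn.Parameter(word_embeddings.weight.clone())
clone_is_shared = cloned_decoder.weight is word_embeddings.weight

print(shapes_match, is_shared, embedding_changed, clone_is_shared)
```

If that assumption holds, the shared case is more than an initialization: a gradient step through the decoder also updates the embedding matrix, while a cloned copy only starts out with the same values.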
So what I am wondering now is: Is there a special reason for cloning the weights? Was this also done in the original BERT model, and did the authors describe or explain it somewhere? Any information or feedback is highly appreciated!