Understanding BertLMPredictionHead

Hey there! :wave:

I am currently trying to understand how some of the Transformer models work, and I'm starting with BERT. Since I'm still figuring things out, I'd highly appreciate corrections to any assumptions I state in this post!

I'm especially interested in how the output embeddings are mapped back to the vocabulary tokens, which is the job of the MLM head:

```python
import torch
from torch import nn


class BertLMPredictionHead(nn.Module):
    def __init__(self, config):
        super(BertLMPredictionHead, self).__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(config.hidden_size,
                                 config.vocab_size,
                                 bias=False)

        self.bias = nn.Parameter(torch.zeros(config.vocab_size))

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states) + self.bias
        return hidden_states
```

So the hidden states processed by the head are first transformed by the BertPredictionHeadTransform class and then fed to a linear layer with a separate ("external") bias. The transform step simply applies a shape-preserving linear layer (hidden_size to hidden_size), followed by an activation function and LayerNorm:

```python
import sys

from torch import nn

# ACT2FN and BertLayerNorm are defined elsewhere in the library's BERT modeling code.


class BertPredictionHeadTransform(nn.Module):
    def __init__(self, config):
        super(BertPredictionHeadTransform, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # The `unicode` check only applies to Python 2; on Python 3 it is short-circuited.
        if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
            self.transform_act_fn = ACT2FN[config.hidden_act]
        else:
            self.transform_act_fn = config.hidden_act
        self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states
```
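To check the shape bookkeeping described above, here is a minimal, self-contained sketch of the same head. It is not the library code: `SimpleNamespace` stands in for `BertConfig`, `nn.GELU` stands in for `ACT2FN[config.hidden_act]`, `nn.LayerNorm` stands in for `BertLayerNorm`, the class name `MiniLMHead` is hypothetical, and the sizes are toy values rather than BERT-base.

```python
from types import SimpleNamespace

import torch
from torch import nn

# Hypothetical stand-in for BertConfig, using toy sizes.
config = SimpleNamespace(hidden_size=16, vocab_size=100, layer_norm_eps=1e-12)


class MiniLMHead(nn.Module):
    """Simplified re-creation of transform + decoder + output-only bias."""

    def __init__(self, cfg):
        super().__init__()
        # Transform: dense -> activation -> LayerNorm, keeps the hidden_size dimension
        self.dense = nn.Linear(cfg.hidden_size, cfg.hidden_size)
        self.act = nn.GELU()  # assumption: GELU in place of ACT2FN[config.hidden_act]
        self.layer_norm = nn.LayerNorm(cfg.hidden_size, eps=cfg.layer_norm_eps)
        # Decoder maps hidden_size -> vocab_size; the bias lives in a separate Parameter
        self.decoder = nn.Linear(cfg.hidden_size, cfg.vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(cfg.vocab_size))

    def transform(self, hidden_states):
        return self.layer_norm(self.act(self.dense(hidden_states)))

    def forward(self, hidden_states):
        return self.decoder(self.transform(hidden_states)) + self.bias


head = MiniLMHead(config)
hidden = torch.randn(2, 7, config.hidden_size)  # (batch, seq_len, hidden_size)
transformed = head.transform(hidden)
logits = head(hidden)
print(transformed.shape)  # torch.Size([2, 7, 16])
print(logits.shape)       # torch.Size([2, 7, 100])
```

The transform leaves the shape untouched; only the decoder changes the last dimension from hidden_size to vocab_size, giving one score per vocabulary token at every position.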

So far so good (unless I got something wrong!). What I don't understand yet is how the weights of the decoder are determined. According to the comment above the decoder definition, the input embeddings are used as the weights of the linear layer. This seems to be enforced by the following code:

```python
def tie_weights(self):
    """ Make sure we are sharing the input and output embeddings.
        Export to TorchScript can't handle parameter sharing so we are cloning them instead.
    """
    self._tie_or_clone_weights(self.cls.predictions.decoder,
                               self.bert.embeddings.word_embeddings)
```
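To make the tie-versus-clone distinction in that docstring concrete, here is a simplified sketch. `tie_or_clone` is a hypothetical helper written for illustration, not the library's `_tie_or_clone_weights` (which may handle extra details). Tying makes both modules hold the very same `Parameter`; cloning only copies the current values into an independent one.

```python
import torch
from torch import nn


def tie_or_clone(decoder: nn.Linear, embedding: nn.Embedding, clone: bool) -> None:
    """Hypothetical, simplified helper; not the library's actual method."""
    if clone:
        # Independent copy that starts from the same values (no sharing)
        decoder.weight = nn.Parameter(embedding.weight.detach().clone())
    else:
        # True sharing: both modules now reference the same Parameter object
        decoder.weight = embedding.weight


emb = nn.Embedding(10, 4)
tied = nn.Linear(4, 10, bias=False)
cloned = nn.Linear(4, 10, bias=False)
tie_or_clone(tied, emb, clone=False)
tie_or_clone(cloned, emb, clone=True)

print(tied.weight is emb.weight)    # True
print(cloned.weight is emb.weight)  # False

# An in-place change to the embedding shows up in the tied decoder, not in the clone
with torch.no_grad():
    emb.weight.add_(1.0)
print(torch.equal(tied.weight, emb.weight))    # True
print(torch.equal(cloned.weight, emb.weight))  # False
```

So the clone path (used for TorchScript export, per the docstring) gives two parameters that happen to start equal and can drift apart, while the normal path keeps a single shared matrix.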

This works because PyTorch's nn.Linear stores its weight as (out_features, in_features), i.e. (vocab_size, hidden_size) here, and multiplies the input by its transpose, so the shapes match the embedding matrix, correct? But the cloning of the weights is just some sort of initialization, and they are still trained further (together with the bias) during the MLM pretraining task, right?
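Both parts of this question can be checked with a toy example (the sizes are arbitrary). `nn.Linear(in_features, out_features)` stores `weight` with shape `(out_features, in_features)` and computes `x @ weight.T`, which matches `nn.Embedding`'s `(num_embeddings, embedding_dim)` layout. When the weight is tied rather than cloned, there is a single tensor that receives gradients from both the embedding lookup and the output projection, so it keeps being trained.

```python
import torch
from torch import nn

torch.manual_seed(0)
vocab_size, hidden_size = 12, 8  # toy sizes for illustration

embedding = nn.Embedding(vocab_size, hidden_size)
decoder = nn.Linear(hidden_size, vocab_size, bias=False)

# Both weights have shape (vocab_size, hidden_size)
print(embedding.weight.shape == decoder.weight.shape)  # True

decoder.weight = embedding.weight  # tie: one Parameter, two uses
print(decoder.weight is embedding.weight)  # True

token_ids = torch.tensor([[1, 2, 3]])
logits = decoder(embedding(token_ids))  # shape (1, 3, vocab_size)
loss = nn.functional.cross_entropy(logits.view(-1, vocab_size), token_ids.view(-1))
loss.backward()

# Row 5 is never looked up by the embedding, yet it still gets a gradient
# through the output projection, because the matrix is shared.
row5_grad = embedding.weight.grad[5].abs().sum().item()
print(row5_grad > 0)
```

An optimizer step on `embedding.weight` therefore updates the input embeddings and the decoder weights at the same time.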

So what I am wondering now is: is there a particular reason for sharing (cloning) the weights? Was this also done in the original BERT model, and is it described or explained anywhere? Any information or feedback is highly appreciated! :upside_down_face:

This is not unique to BERT; it is inherited from the original Transformer. I remember reading about it in that paper, but even before the Transformer, tying input and output embeddings was proposed by Press and Wolf (2017). From the abstract:

Finally, we show that weight tying can reduce the size of neural translation models to less than half of their original size without harming their performance.
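For a sense of scale in BERT-base (`vocab_size=30522`, `hidden_size=768`), tying saves one vocab_size × hidden_size matrix, i.e. 30522 × 768 = 23,440,896 parameters. The sketch below checks the same kind of saving on a toy model; `TinyLM` is hypothetical and contains only an input embedding and an output projection. It relies on `Module.parameters()` yielding a shared parameter only once.

```python
from torch import nn

# Parameters saved by tying in BERT-base: one (vocab_size x hidden_size) matrix
BERT_BASE_VOCAB, BERT_BASE_HIDDEN = 30522, 768
saved_bert_base = BERT_BASE_VOCAB * BERT_BASE_HIDDEN
print(saved_bert_base)  # 23440896


class TinyLM(nn.Module):
    """Hypothetical model: only an input embedding and an output projection."""

    def __init__(self, vocab_size: int, hidden_size: int, tie: bool):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        if tie:
            self.decoder.weight = self.embed.weight


def count_params(model: nn.Module) -> int:
    # parameters() de-duplicates shared Parameters, so a tied matrix counts once
    return sum(p.numel() for p in model.parameters())


vocab_size, hidden_size = 1000, 32
untied = count_params(TinyLM(vocab_size, hidden_size, tie=False))
tied = count_params(TinyLM(vocab_size, hidden_size, tie=True))
print(untied, tied)  # 64000 32000
```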


This actually predates Transformer models and was often done in LSTMs as well. See also the AWD-LSTM paper, for instance.
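As an illustration of tying outside Transformers, here is a minimal LSTM language model with tied input/output embeddings. It is a plain sketch with none of AWD-LSTM's regularization tricks, and `TiedLSTMLM` is a hypothetical name. Tying requires the decoder's input size to equal the embedding size; AWD-LSTM arranges this by sizing its last LSTM layer to match the embedding, whereas this sketch (as an assumption, for simplicity) adds a linear projection instead.

```python
import torch
from torch import nn


class TiedLSTMLM(nn.Module):
    """Minimal LSTM language model with tied input/output embeddings."""

    def __init__(self, vocab_size: int, emb_size: int, hidden_size: int):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_size)
        self.lstm = nn.LSTM(emb_size, hidden_size, batch_first=True)
        # Project LSTM outputs back to emb_size so the tied decoder's shapes line up
        self.proj = nn.Linear(hidden_size, emb_size)
        self.decoder = nn.Linear(emb_size, vocab_size)
        self.decoder.weight = self.embed.weight  # weight tying

    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        out, _ = self.lstm(self.embed(token_ids))  # (batch, seq_len, hidden_size)
        return self.decoder(self.proj(out))        # (batch, seq_len, vocab_size)


model = TiedLSTMLM(vocab_size=50, emb_size=16, hidden_size=32)
tokens = torch.randint(0, 50, (2, 5))
logits = model(tokens)
print(logits.shape)  # torch.Size([2, 5, 50])
```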


Awesome! I must have missed that part in the original Transformer paper, and I didn't realize it was such a widely used technique. I'm going to read up on it asap! Thanks to both of you, I appreciate it very much :slight_smile: