Weight tying means the classifier (output projection) weights reference the embedding weights: they are the exact same weight tensor, so if one changes, the other changes as well, because both point to the same memory location. This is done both to save parameters and because research shows it improves model performance (see Weight Tying Explained | Papers With Code).
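To make this concrete, here is a minimal, self-contained PyTorch sketch (the layer sizes are made up and this is not transformers code, just the bare mechanism):

import torch
import torch.nn as nn

vocab_size, hidden_size = 1000, 64  # made-up sizes, purely for illustration

embedding = nn.Embedding(vocab_size, hidden_size)          # input embeddings
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)   # output classifier

# Tying: the classifier now holds the very same Parameter object as the embedding
lm_head.weight = embedding.weight

assert lm_head.weight is embedding.weight                        # same Python object
assert lm_head.weight.data_ptr() == embedding.weight.data_ptr()  # same memory location

# Any in-place change to one is immediately visible through the other
with torch.no_grad():
    embedding.weight[0, 0] = 42.0
assert lm_head.weight[0, 0].item() == 42.0

During training, gradients from both uses accumulate into that single tensor, and only one (vocab_size, hidden_size) matrix has to be stored and trained, which is where the parameter saving comes from.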
Regarding the code, let's walk through it. In modeling_bert.py we see the following:
class BertForPreTraining(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)

        # Initialize weights and apply final processing
        self.post_init()
BertForPreTraining inherits from BertPreTrainedModel, which in turn inherits from PreTrainedModel, defined in modeling_utils.py.
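You can confirm this inheritance chain directly, since all three classes are exported at the top level of transformers:

from transformers import BertForPreTraining, BertPreTrainedModel, PreTrainedModel

assert issubclass(BertForPreTraining, BertPreTrainedModel)
assert issubclass(BertPreTrainedModel, PreTrainedModel)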
PreTrainedModel is where post_init() is defined, and it looks like this:
def post_init(self):
    """
    A method executed at the end of each Transformer model initialization, to execute code that needs the model's
    modules properly initialized (such as weight initialization).
    """
    self.init_weights()
    self._backward_compatibility_gradient_checkpointing()
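The same convention applies if you build your own model on top of the library: define _init_weights for your modules and call post_init() at the very end of __init__, once every submodule exists. Here is a minimal, hypothetical sketch (ToyConfig and ToyModel are made-up names, and the initialization scheme is arbitrary):

import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel


class ToyConfig(PretrainedConfig):
    model_type = "toy"  # made-up model type, only for this sketch

    def __init__(self, hidden_size=16, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size


class ToyModel(PreTrainedModel):
    config_class = ToyConfig

    def __init__(self, config):
        super().__init__(config)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # Same convention as BertForPreTraining: post_init() runs last,
        # so init_weights() and tie_weights() see fully built submodules.
        self.post_init()

    def _init_weights(self, module):
        # Called for every submodule by init_weights() via self.apply(...)
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()


model = ToyModel(ToyConfig())  # _init_weights has run on model.dense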
Now, let's look at self.init_weights():
def init_weights(self):
    """
    If needed prunes and maybe initializes weights.
    """
    # Prune heads if needed
    if self.config.pruned_heads:
        self.prune_heads(self.config.pruned_heads)

    if _init_weights:
        # Initialize weights
        self.apply(self._init_weights)

        # Tie weights should be skipped when not initializing all weights
        # since from_pretrained(...) calls tie weights anyways
        self.tie_weights()
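Two things worth noting here. The _init_weights in the if statement is a module-level flag in modeling_utils.py that is True unless weight initialization is being deliberately skipped (via the no_init_weights context manager); it is not the same thing as the self._init_weights method passed to self.apply. That method is defined per model class; for BERT it lives on BertPreTrainedModel in modeling_bert.py and looks roughly like this (quoted from memory, the exact body can differ between versions):

def _init_weights(self, module):
    """Initialize the weights"""
    if isinstance(module, nn.Linear):
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)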
And finally, self.tie_weights() eventually performs output_embeddings.weight = input_embeddings.weight (inside _tie_or_clone_weights), so the output classifier and the input embeddings end up sharing a single tensor.
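If you want to see the end result for yourself, you can check a real checkpoint (this downloads bert-base-uncased on first run):

from transformers import BertForPreTraining

model = BertForPreTraining.from_pretrained("bert-base-uncased")

input_embeddings = model.get_input_embeddings()    # bert.embeddings.word_embeddings
output_embeddings = model.get_output_embeddings()  # cls.predictions.decoder

# Both names resolve to one and the same Parameter, stored once in memory
assert output_embeddings.weight is input_embeddings.weight
assert output_embeddings.weight.data_ptr() == input_embeddings.weight.data_ptr()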