Hi,
for some models there is this tie_word_embeddings parameter. I think it is meant for the text-to-text models. Can someone please explain what exactly this parameter does?
Many thanks
Philip
No, this is for all models that have a language modeling head (so even masked language models like BERT or causal language models like GPT-2). The idea is that the embedding weights (vocab_size by hidden_size) are tied with the decoder / LM head (hidden_size by vocab_size), so the model only learns one representation of the words (that is a big matrix!).
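To make that concrete, here is a minimal sketch of the tying trick in plain PyTorch (the class and argument names are just illustrative, not from any particular library):

```python
import torch.nn as nn

class TinyLM(nn.Module):
    def __init__(self, vocab_size=100, hidden_size=16, tie_word_embeddings=True):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden_size)             # weight: (vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)  # weight: (vocab_size, hidden_size)
        if tie_word_embeddings:
            # Point both modules at the same Parameter: one word matrix instead of two.
            self.lm_head.weight = self.embed.weight

    def forward(self, input_ids):
        hidden = self.embed(input_ids)   # stand-in for the full transformer body
        return self.lm_head(hidden)      # logits over the vocabulary

model = TinyLM()
assert model.lm_head.weight.data_ptr() == model.embed.weight.data_ptr()  # shared storage
```

In transformers, with `config.tie_word_embeddings=True` you should see the same thing: `model.get_input_embeddings().weight` and `model.get_output_embeddings().weight` end up being one and the same tensor.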
Arij
July 25, 2022, 12:04pm
Excuse me, maybe I misunderstand something, but according to these lines:
```python
def _init_weights(self, module):
    """Initialize the weights"""
    factor = self.config.initializer_factor  # Used for testing weights initialization
    if isinstance(module, T5LayerNorm):
        module.weight.data.fill_(factor * 1.0)
    elif isinstance(module, (T5Model, T5ForConditionalGeneration, T5EncoderModel)):
        # Mesh TensorFlow embeddings initialization
        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
        module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
        if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
            module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
    elif isinstance(module, T5DenseActDense):
        # Mesh TensorFlow FF initialization
        # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
        # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
        module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
        if hasattr(module.wi, "bias") and module.wi.bias is not None:
            module.wi.bias.data.zero_()
        module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
        if hasattr(module.wo, "bias") and module.wo.bias is not None:
            module.wo.bias.data.zero_()
```
What is the relationship between initializing the language modeling head and tying the weights in this case?
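One way to see the connection (a quick check, assuming the original `t5-small` checkpoint, which as far as I know ties its embeddings by default): when `tie_word_embeddings` is True, `lm_head.weight` is the very same tensor as `shared.weight`, so initializing `module.shared` already covers the head, and the extra `lm_head` init in the snippet is only needed in the untied case.

```python
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small")
shared = model.get_input_embeddings().weight    # module.shared.weight
lm_head = model.get_output_embeddings().weight  # module.lm_head.weight

# Same underlying storage => a single normal_() on module.shared initializes both.
print(shared.data_ptr() == lm_head.data_ptr())  # True when tie_word_embeddings=True
```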
I had just finished reading this issue when, during my research, I found your discussion:
- `transformers` version: 4.18.0, master branch
### Who can help
@patrickvonplaten
I found some significant differences in weight init between the PT and TF implementations of T5.
The **embeddings** (model.shared):
- In PT, according to `T5PreTrainedModel._init_weights`, they are initialized with random normal with std=1.0:
`module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)`
- In TF (TFT5Model), the embeddings are initialized as such:
`self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name="shared")`
Since initializer_range is not being provided, it is using the default, which is `hidden_size**-0.5` (see TFSharedEmbeddings).
This means that in the base model (d=768), the weights in PT are being initialized with **stdev=1.0**, and in TF they are being initialized with **stdev=0.036**.
The **LM head** (model.lm_head):
- In PT, the initializer is not specified, meaning it is being initialized with a uniform distribution in [-sqrt(1/d_model), sqrt(1/d_model)] (https://pytorch.org/docs/stable/generated/torch.nn.Linear.html). The weights don't seem to be initialized in _init_weights either.
`lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)`
- In TF, the initializer is explicitly provided (TFT5ForConditionalGeneration):
`lm_head_initializer = tf.keras.initializers.RandomNormal(mean=0, stddev=config.initializer_factor)`
So, in the base model, the weights in PT are initialized with a uniform distribution of **[-0.036, 0.036]**, and in TF they are initialized with a random normal with **stdev=1.0**.
I'm not entirely sure about the actual implications of this for model training, but at the very least the lm_head weights will have a huge impact on loss values initially.
Based on other transformer models I've seen, the "correct" answer seems to be that both weights should be initialized with stdev=1.0, but neither implementation actually does this.
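For reference, both 0.036 figures quoted above fall out of d_model = 768; a quick sanity check:

```python
import math

d_model = 768
print(d_model ** -0.5)         # ~0.0361: TF embedding init std (hidden_size ** -0.5)
print(math.sqrt(1 / d_model))  # ~0.0361: bound of PyTorch's default nn.Linear uniform init
```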
Is Embedding*Embedding^{T} = I necessarily true-ish?
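One way to probe that empirically, as a rough sketch assuming a checkpoint that ties its embeddings (gpt2 does), and looking at only a slice of the vocabulary to keep the matrix small:

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
E = model.get_input_embeddings().weight.detach()  # (vocab_size, hidden_size)

sub = E[:512]                     # full vocab would give a ~50k x 50k Gram matrix
gram = sub @ sub.T                # (512, 512); identity would mean orthonormal embedding rows
print("mean |diag - 1| :", (gram.diagonal() - 1).abs().mean().item())
print("mean |off-diag| :", (gram - torch.diag(gram.diagonal())).abs().mean().item())
```

Note also that E is (vocab_size x hidden_size) with vocab_size much larger than hidden_size, so E·Eᵀ has rank at most hidden_size and cannot literally equal the identity; tying only says that the input embedding and the output projection share one matrix, not that its rows are orthonormal.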