Initializing T5Encoder model

Hi, I want to get the final hidden-state representations of a few sentences, so I decided to use T5EncoderModel and based my code on the T5 example in the Hugging Face documentation.

When I initialize the T5EncoderModel, I get a warning saying: Some weights of T5EncoderModel were not initialized from the model checkpoint at t5-small and are newly initialized: ['encoder.embed_tokens.weight']. I am worried that this means the embedding matrix of size (vocab_size, model_hidden_size) is being newly initialized.

I want the best possible sentence representations, and if the embeddings were newly initialized, I suspect the hidden-state representations would not be good ones to use.

I would like to know:

  1. Whether the above warning actually means the embedding matrix is newly initialized.
  2. What the best model/method is for getting sentence representations.
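For the second question, one common baseline (an assumption here, not necessarily the best method) is to mean-pool the encoder's final hidden states over the non-padding tokens. The sketch below assumes PyTorch and the transformers T5EncoderModel API; the helper name mean_pooled_embeddings is hypothetical, and the commented usage at the end downloads t5-small.

```python
import torch
from transformers import T5EncoderModel


def mean_pooled_embeddings(
    encoder: T5EncoderModel,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
) -> torch.Tensor:
    """Hypothetical helper: average final hidden states over real (non-padding) tokens.

    Call encoder.eval() first so dropout is disabled.
    Returns a tensor of shape (batch_size, d_model).
    """
    with torch.no_grad():
        outputs = encoder(input_ids=input_ids, attention_mask=attention_mask)
    hidden = outputs.last_hidden_state  # (batch, seq_len, d_model)
    mask = attention_mask.unsqueeze(-1).to(hidden.dtype)  # (batch, seq_len, 1)
    summed = (hidden * mask).sum(dim=1)  # padding positions contribute zero
    counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against all-padding rows
    return summed / counts


# Example usage (downloads t5-small; not executed here):
# from transformers import AutoTokenizer
# tokenizer = AutoTokenizer.from_pretrained("t5-small")
# encoder = T5EncoderModel.from_pretrained("t5-small").eval()
# batch = tokenizer(["A first sentence.", "Another one."], padding=True, return_tensors="pt")
# embeddings = mean_pooled_embeddings(encoder, batch["input_ids"], batch["attention_mask"])
```

Masking before averaging keeps padding tokens from pulling shorter sentences toward the padding representation; whether mean pooling beats other choices (e.g. a model fine-tuned for sentence similarity) depends on the task.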

Thanks in advance.

@sgugger, can you help me understand the issue?

Hi,

It looks to me like the issue is that when you initialize a T5EncoderModel from t5-small, the code doesn't know how to load the encoder's embedding layer, because that layer is shared between the encoder and the decoder. In the code, this is the shared variable (self.shared). Because it is shared by both the encoder and the decoder, it is created in T5Model itself, not inside the encoder or decoder. So my guess is that in the serialized t5-small model (the file on disk), the input embedding weights are saved only under T5Model's shared, and not under the encoder's own embedding (encoder.embed_tokens). Then when you call T5EncoderModel.from_pretrained("t5-small"), the code looks for encoder.embed_tokens.weight, but it doesn't exist in the checkpoint.
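The sharing described above can be inspected directly. This is a minimal sketch, assuming PyTorch and the transformers T5 classes: it builds a tiny, randomly initialized T5Model (the config sizes are arbitrary, hypothetical values chosen so nothing is downloaded) and checks whether the encoder's and decoder's embed_tokens weights occupy the same memory as model.shared. The same_storage helper is hypothetical, and the attribute names reflect the transformers T5 implementation, which may change between versions.

```python
import torch
from transformers import T5Config, T5Model


def same_storage(a: torch.Tensor, b: torch.Tensor) -> bool:
    """Hypothetical helper: True if two tensors point at the same memory."""
    return a.data_ptr() == b.data_ptr()


# Arbitrary tiny sizes (assumption) so the sketch runs without downloading t5-small.
config = T5Config(vocab_size=100, d_model=16, d_kv=4, d_ff=32, num_layers=2, num_heads=4)
model = T5Model(config)

shared = model.shared.weight
encoder_emb = model.encoder.embed_tokens.weight
decoder_emb = model.decoder.embed_tokens.weight

# If the embedding is shared, all three names refer to one matrix,
# so a checkpoint only strictly needs to store one copy of it.
print("encoder shares embedding:", same_storage(shared, encoder_emb))
print("decoder shares embedding:", same_storage(shared, decoder_emb))
```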

However, I noticed that T5EncoderModel has a set_input_embeddings method. This feels a little hacky, but it seems that at least one way to solve the problem is to do something like:

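```python
from transformers import T5Model, T5EncoderModel

# The full encoder-decoder model loads its shared embedding from the checkpoint
model = T5Model.from_pretrained("t5-small")
# Loading the encoder on its own is what triggers the warning
encoder = T5EncoderModel.from_pretrained("t5-small")
# Point the encoder's input embeddings at the full model's shared embedding module
encoder.set_input_embeddings(model.shared)
encoder.save_pretrained("encoder-checkpoint")

# Loads without any warnings
encoder = T5EncoderModel.from_pretrained("encoder-checkpoint")
```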
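Whether the warning reflects a real problem can also be tested by comparing the embedding matrix the encoder actually uses (encoder.embed_tokens, the key named in the warning) against T5Model's shared matrix. This is a sketch that assumes these attribute names match your transformers version; embeddings_match is a hypothetical helper, and the commented usage downloads t5-small.

```python
import torch
from transformers import T5EncoderModel, T5Model


def embeddings_match(encoder: T5EncoderModel, full_model: T5Model) -> bool:
    """Hypothetical helper: compare the encoder's token embeddings with T5Model.shared.

    encoder.encoder.embed_tokens is the module behind the key
    'encoder.embed_tokens.weight' named in the warning.
    """
    enc_weight = encoder.encoder.embed_tokens.weight
    full_weight = full_model.shared.weight
    return enc_weight.shape == full_weight.shape and torch.equal(enc_weight, full_weight)


# Example usage (downloads t5-small; not executed here):
# full = T5Model.from_pretrained("t5-small")
# enc = T5EncoderModel.from_pretrained("t5-small")
# print(embeddings_match(enc, full))
```

If this prints True for a directly loaded T5EncoderModel, the encoder is already using the pretrained embeddings and the workaround is not needed for that setup; if it prints False, the embedding matrix really was left at its random initialization.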

I’d be curious to learn whether there’s a less roundabout solution. Also, there may be better checkpoints to use than t5-small, ones intended specifically for use with T5EncoderModel.
