Questions about the shape of T5 logits

The shape of the logits in the output is (batch, sequence_length, vocab_size). I don’t understand the sequence_length part. I thought the decoder should predict one word at a time, so the logits should be (batch, vocab_size).

Thank you in advance for any replies!

Hi,

Yes, the decoder does predict one token at a time, but the output always keeps a sequence length dimension. At the start of generation, we feed the decoder start token to the T5 decoder.

Suppose you have trained a T5 model to translate from English to French, and you now want to test it on the English sentence “Welcome to Paris”. The T5 encoder will first encode this sentence, giving last_hidden_states of shape (batch_size, sequence_length, hidden_size). As we only have a single sentence here, batch_size = 1; for simplicity, let’s assume every word is a single token and no special tokens are added, so sequence_length = 3; and the hidden_size of a T5-base model is 768. So the output of the encoder has shape (1, 3, 768).
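As a quick check, here is a minimal sketch of that step (assuming the t5-base checkpoint; the encoder on its own can be loaded with T5EncoderModel):

```python
from transformers import T5Tokenizer, T5EncoderModel

tokenizer = T5Tokenizer.from_pretrained("t5-base")
encoder = T5EncoderModel.from_pretrained("t5-base")

# add_special_tokens=False so that, as assumed above, the 3 words map to 3 tokens
inputs = tokenizer("Welcome to Paris", return_tensors="pt", add_special_tokens=False)
last_hidden_states = encoder(**inputs).last_hidden_state

# (batch_size, sequence_length, hidden_size) = (1, 3, 768), assuming each word is one token
print(last_hidden_states.shape)
```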

Next, we have the decoder. We first give it config.decoder_start_token_id, which for T5 is equal to 0 (i.e. the id of the pad token, <pad>). This is our only token at the beginning, hence sequence_length = 1. So the input to the decoder (once the decoder start token id is embedded into a vector) has shape (batch_size, sequence_length, hidden_size) = (1, 1, 768), and the decoder outputs scores for each token of the vocabulary, hence shape (batch_size, sequence_length, vocab_size) = (1, 1, 32100) (strictly speaking, the t5-base checkpoint pads its vocabulary to 32128, which is the number you will see in practice). This indicates which token T5 thinks should follow the pad token (so ideally it should predict “Bienvenue”).
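A hedged sketch of this first decoding step (T5ForConditionalGeneration wraps both encoder and decoder, so we pass token ids and let the model do the embedding itself):

```python
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")

input_ids = tokenizer("Welcome to Paris", return_tensors="pt", add_special_tokens=False).input_ids

# the decoder starts from config.decoder_start_token_id (= 0, the <pad> token for T5)
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)

# (batch_size, sequence_length, vocab_size) = (1, 1, 32128) for this checkpoint
print(outputs.logits.shape)

next_token_id = outputs.logits[:, -1, :].argmax(-1)  # the token T5 thinks follows <pad>
```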

Next, we give <pad> Bienvenue as input to the decoder, so our sequence length is now 2. After both tokens are embedded, the input to the decoder has shape (1, 2, 768), and it outputs a tensor of shape (1, 2, 32100). We are only interested in the logits at the last position, which we take as the prediction for the next token (ideally it should predict “à”).
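Continuing the same sketch, we append the predicted token and run the decoder again; only the last position of the new logits is used:

```python
# append the previously predicted token, so the decoder input now has length 2
decoder_input_ids = torch.cat([decoder_input_ids, next_token_id.unsqueeze(-1)], dim=-1)

outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)

print(outputs.logits.shape)                          # (1, 2, 32128)
next_token_id = outputs.logits[:, -1, :].argmax(-1)  # prediction for the third token
```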


Thank you for replying. When I run the example in the T5 docs, I get logits of size (1, 7, 32128).


Does this mean that the decoder was forwarded 7 times during training (with beam search, top-p sampling, or greedy search)? If so, where does this happen in the source code? I haven’t found it. I know it forwards 7 times when calling model.generate, but not when training or computing the loss.

Thanks again for your time!

Does this mean that the decoder was forwarded 7 times during training (with beam search, top-p sampling, or greedy search)?

No. During training, each training example is only forwarded once. We compare the outputs of the model (i.e. logits) to the labels, for all token positions.
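For example, a single training step can look like this (a minimal sketch, assuming the t5-base checkpoint and a made-up translation pair; when labels are passed, the model builds the decoder inputs internally and scores every position in one forward pass):

```python
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")

input_ids = tokenizer("translate English to French: Welcome to Paris", return_tensors="pt").input_ids
labels = tokenizer("Bienvenue à Paris", return_tensors="pt").input_ids

outputs = model(input_ids=input_ids, labels=labels)  # one forward pass, no sequential decoding

print(outputs.logits.shape)  # (1, target_length, 32128): a score vector for every target position
print(outputs.loss)          # cross-entropy between the logits and the labels
```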

Only at inference time does generation happen sequentially (using .generate()), because then we feed the model’s predictions back in as input for the following step.
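So the sequential loop only appears at inference time, e.g. (reusing the model, tokenizer and input_ids from the sketch above; greedy search by default, beam search only if num_beams is set):

```python
# .generate() runs the decoder step by step, feeding each predicted token back in
generated_ids = model.generate(input_ids, max_length=20)
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True))
```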


I think I understand now. For training, the decoder input is the label sequence (of length 7 here) shifted right by one position, so a single forward pass is enough. For inference, the decoder starts from just the decoder start token and has to be run forward several times.
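For illustration, a hedged sketch of that right shift (the label ids below are made up; at the time of writing, T5 does the same thing internally via a _shift_right helper, which additionally replaces any -100 labels with the pad token):

```python
import torch

decoder_start_token_id = 0  # T5's decoder start token (= <pad>)

labels = torch.tensor([[250, 12, 1489, 7, 53, 1820, 1]])  # 7 made-up target token ids

# prepend the start token and drop the last label: still length 7, shifted right by one
start = torch.full((labels.size(0), 1), decoder_start_token_id)
decoder_input_ids = torch.cat([start, labels[:, :-1]], dim=-1)
```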

Thank you for your time!