(First token generation puzzle) Why does transformers take the last position as the output when generating the first token in language generation?

Here, in this line:
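(As far as I can tell, it's roughly this line from the generation loop in the transformers library:)

```python
# Paraphrased from the transformers generation loop: take the logits at the
# last position of the sequence dimension, i.e. [1, 11, 32000] -> [1, 32000]
next_token_logits = outputs.logits[:, -1, :]
```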

When I am using llama2-7b to generate text, my prompt has 8 words and I get an outputs.logits tensor with shape [1, 11, 32000] (which means 11 tokens), and this line of code gives a next_token_logits tensor with shape [1, 32000] as the next-token prediction.

My question is: why does this line of code just pick the last position (index -1) of the 11 from [1, 11, 32000] as the predicted token's logits?

(I am new here, and this is somewhat confusing to me: I expected the LLM to output something like [1, 12, 32000] given [1, 11, 32000], and to pick the last position of the 12, not what I described above.)

The logits outputted at each position are for predicting what token should come next after that token. If you input 11 tokens, the logits outputted at position 11 are predictions for what the 12th token should be (and this is true for the previous tokens too; the logits outputted at position 5 are predictions for what token 6 should be, but obviously you already have token 6, so we ignore those logits). To actually generate the 12th token, we convert the logits to probabilities and sample from that distribution to pick what should come after the 11th token. That token will then get fed into the model, leading to predictions for what the 13th token should be. And so on.

"(I am new here, and this is somewhat confusing to me: I expected the LLM to output something like [1, 12, 32000] given [1, 11, 32000], and to pick the last position of the 12, not what I described above.)"

The reason it doesn't work like this is that at this point, the 12th token doesn't exist yet. It only comes into existence once we've sampled it from the logits produced at the 11th position (the predictions for what the 12th token should be).
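If it helps, here is a bare-bones sketch of that loop. It is not the actual generate() implementation, and it re-feeds the whole sequence every step for clarity (the real code caches past computation). The checkpoint name is just the Llama-2 model you mentioned; any causal LM would do:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Llama-2-7b-hf"            # or any other causal LM
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

# An 8-word prompt; the exact token count after tokenization depends on the tokenizer.
input_ids = tokenizer("The quick brown fox jumps over the lazy", return_tensors="pt").input_ids

with torch.no_grad():
    for _ in range(20):                                        # generate 20 new tokens
        logits = model(input_ids).logits                       # [1, seq_len, vocab_size]
        next_token_logits = logits[:, -1, :]                   # predictions for position seq_len + 1
        probs = torch.softmax(next_token_logits, dim=-1)       # logits -> probabilities
        next_token = torch.multinomial(probs, num_samples=1)   # sample one token id, shape [1, 1]
        input_ids = torch.cat([input_ids, next_token], dim=-1) # append it and repeat

print(tokenizer.decode(input_ids[0]))
```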


Thanks for your kind reply!
Besides, you said that LLMs are "predicting what token should come next after that token". Is that because LLMs are trained to do this (for instance, the training objective is predicting the next word)?

If LLMs are trained like this, then since you mentioned
" the logits outputted at position 5 are predictions for what token 6 should be, but obviously you already have token 6, so we ignore those logits"

So in practice, is that the reason why we need the attention mask: to skip generating token 6 (to save computation), since we can ignore token 6, which we already have?

Right, language models are trained like that – a training example is fed into the model and each token in the sequence is trying to predict the token that comes next. For example, if the training example is:

<s> The quick brown fox jumps over the lazy dog.</s>

The logits outputted for the <s> token are to predict the word The, the logits for The are to predict the word quick, and so on. But unlike generating a sequence at inference time, this all happens with just one pass through the model. Every token in the training sequence is a classification problem to predict what should come next, and each of these classification problems is run in parallel.
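In code, that objective looks roughly like this (just a sketch with random stand-in tensors, not the actual training code): every position is classified against the token one step to its right, and all of those classifications happen in the same forward pass.

```python
import torch
import torch.nn.functional as F

batch, seq_len, vocab = 1, 11, 32000
input_ids = torch.randint(0, vocab, (batch, seq_len))  # stand-in for a tokenized training example
logits = torch.randn(batch, seq_len, vocab)            # stand-in for model(input_ids).logits

# Position t predicts token t+1: drop the logits at the last position (nothing to
# predict after it in this example) and drop the first token from the labels.
shift_logits = logits[:, :-1, :]                       # [1, 10, 32000]
shift_labels = input_ids[:, 1:]                        # [1, 10]

# One cross-entropy over all positions at once = seq_len - 1 parallel classification problems.
loss = F.cross_entropy(shift_logits.reshape(-1, vocab), shift_labels.reshape(-1))
```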

In contrast, when you do inference, the prompt tokens you feed in are processed in one pass through the model, but every generated token after that point is created one at a time, by sampling from the logits produced at the position right before it.
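Concretely, that split looks something like this (a sketch, assuming model is a causal LM from transformers and prompt_ids is your tokenized prompt; the real generate() handles the cache for you, and I'm using greedy argmax instead of sampling to keep it short):

```python
# "Prefill": one pass over the whole prompt. We get logits at every prompt position
# plus the key/value cache, so the prompt never has to be re-processed.
out = model(prompt_ids, use_cache=True)
next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)  # [1, 1]
past = out.past_key_values

# "Decode": one token at a time. Each step feeds only the newly generated token
# together with the cache of everything that came before it.
for _ in range(20):
    out = model(next_token, past_key_values=past, use_cache=True)
    past = out.past_key_values
    next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
```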

"So in practice, is that the reason why we need the attention mask: to skip generating token 6 (to save computation), since we can ignore token 6, which we already have?"

The attention mask is actually there to make sure that inside the model, the hidden states for token 6 only attend to the tokens that came before it, not after. So if at inference time you fed

<s> The quick brown fox jumps

into the model, the attention mask ensures that The only attends to <s>, quick only attends to <s> The, fox only attends to <s> The quick brown (not jumps), etc. The way attention works here is that each token should only see what came before it, and the mask enforces this. You can think of it like this: if a token is the present, the previous tokens are the past and it attends to those (plus itself). But it's not allowed to see the future; it's only trying to predict the future.[1]
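You can see the pattern by just printing a causal mask for that prompt: row i shows which positions token i may attend to (1 = allowed, 0 = masked out):

```python
import torch

tokens = ["<s>", "The", "quick", "brown", "fox", "jumps"]
causal_mask = torch.tril(torch.ones(len(tokens), len(tokens), dtype=torch.int))
print(causal_mask)
# tensor([[1, 0, 0, 0, 0, 0],   <s>   sees: <s>
#         [1, 1, 0, 0, 0, 0],   The   sees: <s> The
#         [1, 1, 1, 0, 0, 0],   quick sees: <s> The quick
#         [1, 1, 1, 1, 0, 0],   brown sees: <s> The quick brown
#         [1, 1, 1, 1, 1, 0],   fox   sees: <s> The quick brown fox
#         [1, 1, 1, 1, 1, 1]])  jumps sees: the whole prompt
```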

So going back to your example: while we can ignore the logits output for token 5 (because we already have token 6), we can't actually ignore token 6 itself. It still needs to be fed into the model and processed by having it attend to the tokens that came before it. This is a very hand-wavy explanation, but in order for the model to predict token 12, it needs an internal "representation" of the full sequence that came before, including token 6 and the role token 6 plays within that sequence.

[1] Technically, not all architectures work this way; many models (namely encoder models) have bidirectional attention, where each token can see both the previous and the subsequent tokens. But the vast majority of generative language models aren't like this.


Thanks for your kind reply!

So according to your description, it seems that the attention mask is actually not needed during inference, but only during training?

Because during inference, even if we wanted to "attend" to the "future", there is no future for us to attend to.

If the attention mask is not needed during inference, why is it included in LLM inference code, including the huggingface transformers lib?

When you provide a prompt during inference, the attention mask is needed for the first forward pass, when all the prompt tokens go through the model at once. If your prompt is:

<s> The quick brown fox

Then the "future" for quick is brown fox, so it shouldn't attend to those two tokens. But it should attend to <s> The. After that, though, when the model is generating one token at a time, there's no need for an attention mask, because a newly generated token only has a past to attend to.
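A tiny illustration, assuming a single unpadded sequence: the prefill pass over the prompt uses the triangular mask, while each later one-token step has nothing to mask out, because everything in the cache is already in the new token's past:

```python
import torch

prompt_len = 5  # <s> The quick brown fox
prefill_mask = torch.tril(torch.ones(prompt_len, prompt_len, dtype=torch.int))

# One decode step: a single query (the newly generated token) attending over the
# prompt_len cached tokens plus itself, so every entry is 1 and nothing is masked.
decode_mask = torch.ones(1, prompt_len + 1, dtype=torch.int)

print(prefill_mask)
print(decode_mask)  # tensor([[1, 1, 1, 1, 1, 1]])
```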
