Right, language models are trained like that – a training example is fed into the model and each token in the sequence is trying to predict the token that comes next. For example, if the training example is:
<s> The quick brown fox jumps over the lazy dog.</s>
The logits output for the `<s>` token are used to predict the word `The`, the logits for `The` are used to predict the word `quick`, and so on. But unlike generating a sequence at inference time, this all happens with just one pass through the model. Every token in the training sequence is a classification problem to predict what should come next, and each of these classification problems is run in parallel.
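To make the parallel-classification idea concrete, here’s a minimal sketch in PyTorch. Everything in it is made up for illustration – a toy embedding plus linear head standing in for a real decoder, and invented token ids – but it shows how one forward pass gives a next-token prediction at every position, with the targets being the input ids shifted left by one:

```python
import torch
import torch.nn.functional as F

vocab_size = 32

# Stand-in for a real decoder: an embedding plus a linear head is enough to show the shapes.
embed = torch.nn.Embedding(vocab_size, 16)
head = torch.nn.Linear(16, vocab_size)

# Made-up ids for "<s> The quick brown fox jumps over the lazy dog . </s>"
input_ids = torch.tensor([[0, 5, 6, 7, 8, 9, 10, 14, 11, 12, 13, 1]])

logits = head(embed(input_ids))   # one pass -> shape [1, seq_len, vocab_size]

# Position t is a classifier for the token at position t + 1,
# so the targets are just the input ids shifted left by one.
preds = logits[:, :-1, :]         # the last position has nothing left to predict
targets = input_ids[:, 1:]

loss = F.cross_entropy(preds.reshape(-1, vocab_size), targets.reshape(-1))
print(loss.item())                # one loss scores every position in parallel
```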
In contrast, when you do inference, the prompt tokens you feed in are processed with one pass through the model, but each generated token after that point is created one by one, by sampling from the logits produced by the token that came right before it.
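A rough sketch of that loop (sampled decoding, assuming `model` is any callable that maps token ids to logits – e.g. the toy stand-in above):

```python
import torch

def generate(model, input_ids, max_new_tokens=5):
    # The prompt goes through in one pass; after that, each new token is
    # sampled from the logits at the last position, one step at a time.
    for _ in range(max_new_tokens):
        logits = model(input_ids)                          # [1, seq_len, vocab_size]
        probs = logits[:, -1, :].softmax(dim=-1)           # only the last position matters now
        next_id = torch.multinomial(probs, num_samples=1)  # sample the next token -> [1, 1]
        input_ids = torch.cat([input_ids, next_id], dim=-1)
    return input_ids
```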
So in practice, is that the reason we need the attention mask – to skip generating token 6 (to save computation), since we already have token 6?
The attention mask is actually there to make it so that, inside the model, the hidden states for token 6 only attend to the tokens that came before it, not after. So if at inference time you fed `<s> The quick brown fox jumps` into the model, the attention mask ensures `The` only attends to `<s>`, `quick` only attends to `<s> The`, `fox` only attends to `<s> The quick brown` (not `jumps`), etc. The way attention works is that each token should only see what came before it, and the mask ensures this. The way you can think of it is that if a token is the present, the previous tokens are the past and it attends to those (plus itself). But it’s not allowed to see the future – it’s just trying to predict the future.
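If it helps to see it, here’s a small toy sketch of that causal mask for `<s> The quick brown fox jumps`, with random queries and keys just to show which positions end up with zero attention weight:

```python
import torch

tokens = ["<s>", "The", "quick", "brown", "fox", "jumps"]
seq_len = len(tokens)

# Causal mask: position i may attend to positions <= i (the past plus itself), never the future.
mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# Toy attention scores, only to show where the mask bites.
d = 8
q, k = torch.randn(seq_len, d), torch.randn(seq_len, d)
scores = (q @ k.T) / d**0.5
scores = scores.masked_fill(~mask, float("-inf"))  # masked-out positions get zero weight after softmax
weights = scores.softmax(dim=-1)

# The row for "fox" has non-zero weight on <s> The quick brown (and fox itself), exactly zero on jumps.
print(dict(zip(tokens, weights[tokens.index("fox")].tolist())))
```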
So going back to your example – while we can ignore the logits output for token 5 (because we already have token 6), we can’t actually ignore token 6 itself. It still needs to be fed into the model and processed by having it attend to the tokens that came before it. This will be a very hand-wavy explanation, but in order for the model to predict token 12, it does need an internal “representation” of the full sequence that came before, including token 6 and the role token 6 plays within that sequence.
Technically, not all architectures work this way: a lot of models (namely encoder models) have bidirectional attention, where each token can see both the previous tokens and the subsequent tokens. But the vast majority of language models aren’t like this.
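For contrast, and purely as an illustrative sketch, the bidirectional (encoder-style) case is just the mask without the triangle:

```python
import torch

seq_len = 6

# Decoder-style (causal): each position sees itself and earlier positions only.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))

# Encoder-style (bidirectional): every position sees every other position, past and future.
bidirectional_mask = torch.ones(seq_len, seq_len, dtype=torch.bool)
```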