Past_key_values - why not past_key_values_queries?

My understanding is that when passed a sequence of input vectors, a transformer self-attention block computes three different transformed versions of that sequence: the keys, the queries, and the values. Then it takes the key/query dot products, softmaxes, and takes a weighted average of the values.
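For reference, the computation described above can be sketched in plain Python. This is a minimal single-head version that makes a few assumptions the question doesn't state: no biases, no output projection, the common 1/sqrt(d) scaling of the dot products, and random illustrative weight matrices `Wq`, `Wk`, `Wv`.

```python
import math
import random

def matvec(W, x):
    """Multiply a matrix (a list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(xs, Wq, Wk, Wv):
    """Single-head, unmasked self-attention over a sequence of input vectors."""
    qs = [matvec(Wq, x) for x in xs]  # one query per position
    ks = [matvec(Wk, x) for x in xs]  # one key per position
    vs = [matvec(Wv, x) for x in xs]  # one value per position
    outputs = []
    for q in qs:
        # key/query dot products, scaled by sqrt(d), turned into weights by softmax
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in ks]
        weights = softmax(scores)
        # weighted average of the value vectors
        outputs.append([sum(w * v[j] for w, v in zip(weights, vs))
                        for j in range(len(vs[0]))])
    return outputs

def random_matrix(rows, cols):
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
d = 4
Wq, Wk, Wv = random_matrix(d, d), random_matrix(d, d), random_matrix(d, d)
xs = random_matrix(3, d)  # a sequence of 3 input vectors of size d
out = self_attention(xs, Wq, Wk, Wv)
print(len(out), len(out[0]))  # 3 4
```

With a single input vector the only softmax weight is 1, so the output is exactly that token's value vector.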

How is it that, when we pass past_key_values to a Hugging Face transformer, we don’t have to pass the queries as well? Since we don’t pass in the tokens (input_ids) for the previous positions, there would be no way for the model to recompute their queries, so how can it possibly run the forward pass?


Hey @j3m :wave: When you’re doing auto-regressive text generation, you predict one token at a time. When predicting a given token, the attention layer needs to compute the attention between the most recent token and all tokens generated so far: you use the query of the last token, but the keys and values of all tokens generated so far. This means there is no benefit in caching the queries, but caching the keys and values saves you from recomputing them for every previous token, in every layer, at every generation step :slight_smile:
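To make the “query from the last token, keys and values from all tokens” point concrete, here is a single-layer, single-head sketch in plain Python. Assumptions: no batching, no biases, random illustrative weights; `decode_step` and the list-based cache are hypothetical names, not the Hugging Face API. It checks that attending with only the newest token's query against cached keys and values gives the same result, for that token, as recomputing everything with a causal mask.

```python
import math
import random

def matvec(W, x):
    """Multiply a matrix (a list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(q, keys, values):
    """Scaled dot-product attention of a single query over lists of keys/values."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    weights = softmax(scores)
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]

def causal_attention_full(xs, Wq, Wk, Wv):
    """No cache: recompute q, k, v for every token; token i only sees tokens 0..i."""
    qs = [matvec(Wq, x) for x in xs]
    ks = [matvec(Wk, x) for x in xs]
    vs = [matvec(Wv, x) for x in xs]
    return [attend(qs[i], ks[: i + 1], vs[: i + 1]) for i in range(len(xs))]

def decode_step(x_new, k_cache, v_cache, Wq, Wk, Wv):
    """With a cache: compute q, k, v for the newest token only."""
    q = matvec(Wq, x_new)              # used for this step, then discarded
    k_cache.append(matvec(Wk, x_new))  # keys and values are kept for later steps
    v_cache.append(matvec(Wv, x_new))
    return attend(q, k_cache, v_cache)

def rand(rows):
    return [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(rows)]

random.seed(1)
d, n = 4, 5
Wq, Wk, Wv = rand(d), rand(d), rand(d)
xs = rand(n)

k_cache, v_cache = [], []
cached = [decode_step(x, k_cache, v_cache, Wq, Wk, Wv) for x in xs]
full = causal_attention_full(xs, Wq, Wk, Wv)
```

Only `k_cache` and `v_cache` grow from step to step; nothing about earlier queries is stored.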

For further reference, check the Masked Self-Attention section of The Illustrated GPT-2 :hugs:


Are you saying that during the decoding step, the key/value and query vectors do not share the same sequence length dimension? This would mean that at every step of decoding the query would have dimensions [Batch, 1, EmbeddingLength], while the keys and values would have dimensions [Batch, SeqLenUpTillCurrentPredictedToken, EmbeddingLength] (and the attention scores [Batch, 1, SeqLenUpTillCurrentPredictedToken]). Which honestly makes a lot of sense, but I guess I’ve gotten different impressions from different sources online.
Even in The Illustrated GPT-2, at this part, I got the impression that when he generates the KVQ vectors for the second word, he says they can reuse the Q vector of the first word, implying that the KVQ vectors share the same sequence length dimension.

The key/value and query vectors do not have the same sequence length: you need the key and value for each token in the sequence, whereas you only need the query of the last token. Because of masked attention, there’s no point in recalculating q, k and v for the previous tokens: you would get exactly the same q, k and v in each decoder layer, since those tokens cannot see the last token that has been added to the sequence.
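The shapes discussed above can be checked with a small batched sketch of one decoding step, using plain nested lists and a single head. `B`, `T` and `E` stand for batch size, number of tokens so far and embedding size; `decode_attention` is a hypothetical helper and the input numbers are arbitrary. The query covers one position, the cached keys and values cover all `T` positions, and it is the attention weights, not the keys and values, that have shape `[B, 1, T]`.

```python
import math

def shape(t):
    """Shape of a rectangular nested list, e.g. B x T x E -> (B, T, E)."""
    dims = []
    while isinstance(t, list):
        dims.append(len(t))
        t = t[0]
    return tuple(dims)

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def decode_attention(q, K, V):
    """q: [B, 1, E]; K, V: [B, T, E] -> (weights [B, 1, T], output [B, 1, E])."""
    E = len(q[0][0])
    weights, output = [], []
    for qb, Kb, Vb in zip(q, K, V):  # loop over the batch
        scores = [sum(a * b for a, b in zip(qb[0], k)) / math.sqrt(E) for k in Kb]
        w = softmax(scores)  # one weight per cached position
        weights.append([w])
        output.append([[sum(wi * v[j] for wi, v in zip(w, Vb)) for j in range(E)]])
    return weights, output

B, T, E = 2, 5, 3
q = [[[0.1 * (b + e) for e in range(E)]] for b in range(B)]                         # [B, 1, E]
K = [[[0.01 * (b + t + e) for e in range(E)] for t in range(T)] for b in range(B)]  # [B, T, E]
V = [[[float(t) for _ in range(E)] for t in range(T)] for _ in range(B)]            # [B, T, E]

weights, output = decode_attention(q, K, V)
print(shape(q), shape(K), shape(weights), shape(output))
# (2, 1, 3) (2, 5, 3) (2, 1, 5) (2, 1, 3)
```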

But why do we need only the ks and vs, and not the qs, of the previous tokens? To answer this, think about why you need the qs in the first place: the q of a token “observes” the tokens up to and including itself to compute the attention scores (q·k). Due to the causal masking of the tokens to its right, this computation in each decoder layer is identical no matter how many tokens you add on the right side. Consequently, the updated embedding that comes out of a decoder layer for that token won’t change, and the k and v computed from it in the next decoder layer will be the same as in the previous inference call to the model. So you can just cache the ks and vs and never deal with the previous qs, since, due to the causal masking, those embeddings are never going to change.
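The argument above can be checked numerically with a toy two-layer causal stack in plain Python. Each layer here is only single-head causal attention plus a residual connection; the MLP and layer norm of a real GPT-2 block are left out (they act on each position independently, so they should not change the conclusion). `causal_layer` and `hidden_states` are hypothetical helpers with random illustrative weights. Appending a token leaves the hidden states of the earlier tokens unchanged at the input of every layer, so the keys and values computed from them are unchanged too.

```python
import math
import random

def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(q, keys, values):
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    weights = softmax(scores)
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]

def causal_layer(hs, weights):
    """Causal single-head attention plus a residual connection."""
    Wq, Wk, Wv = weights
    qs = [matvec(Wq, h) for h in hs]
    ks = [matvec(Wk, h) for h in hs]
    vs = [matvec(Wv, h) for h in hs]
    return [[h + a for h, a in zip(hs[i], attend(qs[i], ks[: i + 1], vs[: i + 1]))]
            for i in range(len(hs))]

def hidden_states(xs, layers):
    """Hidden states entering each layer, followed by the final output."""
    states = [xs]
    for weights in layers:
        states.append(causal_layer(states[-1], weights))
    return states

def rand(rows):
    return [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(rows)]

random.seed(2)
d = 4
layers = [(rand(d), rand(d), rand(d)) for _ in range(2)]
xs = rand(4)

before = hidden_states(xs[:3], layers)  # 3 tokens
after = hidden_states(xs, layers)       # the same 3 tokens plus one more

for layer_idx, (old, new) in enumerate(zip(before, after)):
    same = all(math.isclose(a, b) for i in range(3) for a, b in zip(old[i], new[i]))
    print(layer_idx, same)  # True for every layer
```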

A good article adding more detail on top of the illustrated GPT-2 is Speeding up the GPT - KV cache | Becoming The Unbeatable. What happens is that, when sampling the next token, you generate the new key, value and query corresponding to the last-sampled token. Then you append that new key and value to the cache of the previous keys and values. As a result, you end up with a value and key for each token in the sequence. For the query, you only need the one of the last-sampled token, so there’s no need to cache the previous ones.
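The procedure described here (compute the new token's key, value and query, append the key and value to the cache, use only the new query) can be sketched per layer in plain Python. The list of `(keys, values)` pairs below loosely mirrors the idea of one cached key/value set per layer in `past_key_values`, but it is a hypothetical toy structure, not the actual Hugging Face format. Layers are again causal single-head attention plus a residual, with no MLP or layer norm, and random illustrative weights.

```python
import math
import random

def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(q, keys, values):
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    weights = softmax(scores)
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]

def causal_layer(hs, weights):
    """Causal single-head attention plus a residual connection."""
    Wq, Wk, Wv = weights
    qs = [matvec(Wq, h) for h in hs]
    ks = [matvec(Wk, h) for h in hs]
    vs = [matvec(Wv, h) for h in hs]
    return [[h + a for h, a in zip(hs[i], attend(qs[i], ks[: i + 1], vs[: i + 1]))]
            for i in range(len(hs))]

def full_forward(xs, layers):
    """No cache: rerun every layer over the whole sequence."""
    hs = xs
    for weights in layers:
        hs = causal_layer(hs, weights)
    return hs

def cached_step(x_new, layers, cache):
    """With a cache: push only the newest token through the layers."""
    h = x_new
    for (Wq, Wk, Wv), (keys, values) in zip(layers, cache):
        q = matvec(Wq, h)             # needed for this step only
        keys.append(matvec(Wk, h))    # appended to this layer's cache
        values.append(matvec(Wv, h))
        h = [a + b for a, b in zip(h, attend(q, keys, values))]
    return h

def rand(rows):
    return [[random.gauss(0.0, 1.0) for _ in range(d)] for _ in range(rows)]

random.seed(4)
d = 4
layers = [(rand(d), rand(d), rand(d)) for _ in range(3)]
xs = rand(5)

cache = [([], []) for _ in layers]  # one (keys, values) pair per layer
step_outputs = [cached_step(x, layers, cache) for x in xs]
full_outputs = full_forward(xs, layers)
```

Each entry of `step_outputs` matches the corresponding row of `full_outputs`, while the cache ends up holding one key and one value per token in every layer and no queries at all.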



Read this famous article and find the following sentence:

Notice that the second path is the only one that’s active in this calculation. Each layer of GPT-2 has retained its own interpretation of the first token and will use it in processing the second token

It explains everything.

The key is “masked self-attention”.

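If you used “traditional” (unmasked, bidirectional) self-attention, which is not what GPT-style decoders use, you could cache Q, K and V for all tokens at Layer 1, as you mentioned in your question. But from Layer 2 on, every earlier token's input would change whenever a new token is added, so everything would have to be recomputed and there would be nothing to cache.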
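The contrast with unmasked attention can be sketched the same way: a toy single-head layer in plain Python with a residual connection, random illustrative weights, and hypothetical helper names. The Layer 1 keys of the earlier tokens do not depend on the new token, but their Layer 1 outputs, which are the inputs to Layer 2, do change, so Layer 2 keys and values would have to be recomputed.

```python
import math
import random

def matvec(W, x):
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attend(q, keys, values):
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    weights = softmax(scores)
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]

def bidirectional_layer(hs, weights):
    """Unmasked single-head attention plus a residual: every token sees every token."""
    Wq, Wk, Wv = weights
    qs = [matvec(Wq, h) for h in hs]
    ks = [matvec(Wk, h) for h in hs]
    vs = [matvec(Wv, h) for h in hs]
    return [[h + a for h, a in zip(hs[i], attend(qs[i], ks, vs))] for i in range(len(hs))]

def rand(rows, cols, scale=1.0):
    return [[random.gauss(0.0, scale) for _ in range(cols)] for _ in range(rows)]

random.seed(3)
d = 4
weights = (rand(d, d, 0.5), rand(d, d, 0.5), rand(d, d, 0.5))
xs = rand(4, d)

# Layer 1 keys are projections of the inputs, so they ignore the new token...
keys_before = [matvec(weights[1], x) for x in xs[:3]]
keys_after = [matvec(weights[1], x) for x in xs]
print(keys_before == keys_after[:3])  # True

# ...but the Layer 1 outputs (the Layer 2 inputs) of the earlier tokens change.
out_before = bidirectional_layer(xs[:3], weights)
out_after = bidirectional_layer(xs, weights)
print(out_before != out_after[:3])  # True
```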