Past_key_values - why not past_key_values_queries?

The key/value and query vectors do not have the same sequence length: you need the key and value for each token in the sequence, whereas you only need the query of the last token. Because of the causal attention mask, there is no point in recalculating the q, k and v for the previous tokens: you would get exactly the same q, k and v in each decoder layer, since those tokens cannot see the token that was just appended to the sequence.

But why do we only need the ks and vs, and not the qs, of the previous tokens? To answer this, think about why you need the qs in the first place: the q of a token "observes" the other tokens to compute the attention scores (q·k). Because of the causal masking of the tokens to its right, this computation is identical in each decoder layer no matter how many tokens you append on the right. Consequently, the updated embedding of that token does not change after the decoder layer, and the k and v it produces for the next layer are the same as in the previous inference call to the model. So you can just cache the k and v and not deal with the qs, since the embeddings of the previous tokens are never going to change, thanks to the causal masking.
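You can check this invariance numerically. Here is a minimal single-head sketch in PyTorch (random weights, no layer norm or multi-head logic, purely illustrative): the attention outputs for the first four positions are identical whether or not a fifth token has been appended, which is exactly why their downstream k and v can be cached.

```python
import torch

torch.manual_seed(0)
d = 8
x = torch.randn(5, d)                          # embeddings for 5 tokens
W_q, W_k, W_v = (torch.randn(d, d) for _ in range(3))

def causal_attention(x):
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    scores = (q @ k.T) / d ** 0.5
    # Mask out every position to the right of each token.
    mask = torch.triu(torch.ones_like(scores), diagonal=1).bool()
    scores = scores.masked_fill(mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

out_4 = causal_attention(x[:4])   # before appending the 5th token
out_5 = causal_attention(x)       # after appending it
# The first 4 rows are unchanged: the appended token is masked out for
# every earlier position, so their outputs (and hence the k and v they
# feed to the next layer) are exactly the same.
print(torch.allclose(out_4, out_5[:4]))  # True
```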

A good article adding more detail on top of The Illustrated GPT-2 is Speeding up the GPT - KV cache | Becoming The Unbeatable. When sampling the next token, you generate the new key, value and query corresponding to the last-sampled token only. You then append that new key and value to the cache of previous keys and values, so you end up with a key and value for every token in the sequence. For the query, you only need the one of the last-sampled token, so there is no need to cache the previous ones.
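In code, a single decoding step looks roughly like this. This is a hedged single-head sketch: `decode_step` and its arguments are made-up names for illustration, not the actual `transformers` API, which stores the cached tensors in `past_key_values` with extra batch and head dimensions.

```python
import torch

def decode_step(x_new, W_q, W_k, W_v, past_k, past_v):
    """One single-head decoding step with a KV cache (illustrative only)."""
    # Project only the embedding of the newly sampled token.
    q_new = x_new @ W_q                            # (d,) - used once, never cached
    k_new = x_new @ W_k
    v_new = x_new @ W_v

    # Append the new key/value to the cache: one k and v per token so far.
    keys = torch.cat([past_k, k_new[None, :]])     # (seq_len + 1, d)
    values = torch.cat([past_v, v_new[None, :]])

    # The last token may attend to every position, so no mask is needed here.
    scores = keys @ q_new / keys.shape[-1] ** 0.5  # (seq_len + 1,)
    out = torch.softmax(scores, dim=0) @ values    # (d,)
    return out, keys, values                       # keys/values become the new cache
```

Note that the query never enters the cache: it is consumed immediately to score the new token against all cached keys, then discarded.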
