This is with respect to inference using models like Mistral-7B or Llama.
In my understanding, the KV cache should grow as more tokens are processed; however, I observe that it shrinks instead. For example, in transformers/src/transformers/models/mistral/modeling_mistral.py, see the following line:
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
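For reference, here is a minimal sketch of how I understood that update call to behave (my own simplification, not the actual DynamicCache code): the new key/value states for a layer are concatenated onto the cached ones along the sequence-length dimension, which is why I expected the cache to grow with every step.

import torch

class ToyKVCache:
    # Toy stand-in for the per-layer KV cache; each entry has shape
    # [batch_size, num_heads, seq_len, head_dim].
    def __init__(self):
        self.key_cache = {}
        self.value_cache = {}

    def update(self, key_states, value_states, layer_idx):
        if layer_idx not in self.key_cache:
            # Prefill: store the key/value states for this layer as-is.
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            # Decode: append the new states along the seq_len dimension (dim=-2).
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

cache = ToyKVCache()
# Prefill with 5 tokens, then decode one more token.
k, v = torch.randn(1, 8, 5, 64), torch.randn(1, 8, 5, 64)
print(cache.update(k, v, layer_idx=0)[0].shape)   # torch.Size([1, 8, 5, 64])
k1, v1 = torch.randn(1, 8, 1, 64), torch.randn(1, 8, 1, 64)
print(cache.update(k1, v1, layer_idx=0)[0].shape) # torch.Size([1, 8, 6, 64]) -- seq_len grows

With this mental model, seq_len in the cached tensors should only ever increase.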
The cache past_key_value, instead of growing, shrinks in size. The initial shape is [batch_size, num_heads, seq_len, head_dim], and with increasing iterations, while batch_size, num_heads, and head_dim remain the same, seq_len decreases. This results in a shrinking cache.
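For context, this is roughly how I am inspecting the shapes (a sketch of my setup; the checkpoint name, dtype/device settings, and the assumption that past_key_values can be indexed into a per-layer (key, value) pair are mine and may differ across transformers versions):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "mistralai/Mistral-7B-v0.1"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")

inputs = tokenizer("The quick brown fox", return_tensors="pt").to(model.device)
input_ids = inputs.input_ids
past_key_values = None

with torch.no_grad():
    for step in range(5):
        outputs = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values
        # Layer 0 key states; expected shape [batch_size, num_heads, seq_len, head_dim].
        key, value = past_key_values[0]
        print(f"step {step}: key cache shape = {tuple(key.shape)}")
        # Greedily pick the next token and feed only that token back in.
        input_ids = outputs.logits[:, -1:].argmax(dim=-1)

I expected seq_len in these printed shapes to increase by one per decoding step.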
Can anyone explain why the cache is shrinking instead of growing in size?