What is the purpose of 'use_cache' in decoder?

Hi @lifelongeek!

The cache is only used for generation, not for training.

Say you have M input tokens and want to generate N output tokens.

Without the cache, the model computes the M hidden states for the input, then generates a first output token. Then, it computes the hidden state for the first generated token, and generates a second one. Then, it computes the hidden states for the first two generated tokens to generate the third one, and so on and so forth.

However, since the output side is auto-regressive, each token only attends to the tokens before it, so an output token's hidden state remains the same at every further generation step once it has been computed. Recomputing it every time we want to generate a new token is wasteful.

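With the cache, the model saves these states once they have been computed (in practice, the attention keys and values of each layer), and at each time step only computes the one for the most recently generated output token, re-using the saved ones for all previous tokens. For a transformer model, this reduces the total generation cost from O(n^3) to O(n^2) in the sequence length n: a step at length t costs on the order of t^2 attention scores without the cache, but only t with it.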
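To make the difference concrete, here is a minimal pure-Python sketch of a toy single-layer, single-head attention decoder, run once without and once with a cache. It is only an illustration under simplifying assumptions: all names (`embed`, `key_value`, `pick_token`, `generate_no_cache`, `generate_with_cache`) are made up for this example and are not the `transformers` API, and the "model" uses fixed toy functions instead of learned weights. It counts query-key scores as a rough proxy for attention cost.

```python
import math

DIM = 4     # toy hidden size (assumption for this sketch)
VOCAB = 50  # toy vocabulary size (assumption for this sketch)

def embed(token):
    # Hypothetical deterministic embedding, standing in for a learned table.
    return [math.sin((token + 1) * (d + 1)) for d in range(DIM)]

def key_value(token):
    # Toy key/value "projections" of the token embedding.
    x = embed(token)
    return [0.5 * a for a in x], [a + 1.0 for a in x]

def attend(query, keys, values, stats):
    # Softmax attention of one query over the keys/values it may see.
    stats["scores"] += len(keys)  # one query-key score per visible token
    scores = [sum(q * k for q, k in zip(query, key)) for key in keys]
    peak = max(scores)
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(DIM)]

def pick_token(hidden):
    # Toy greedy "head": maps a hidden state to a token id.
    return int(abs(sum(hidden)) * 1000) % VOCAB

def generate_no_cache(prompt, n_new):
    stats = {"scores": 0}
    tokens = list(prompt)
    for _ in range(n_new):
        # Recompute keys, values, and hidden states for every position.
        kv = [key_value(t) for t in tokens]
        keys = [k for k, _ in kv]
        values = [v for _, v in kv]
        hidden = [attend(embed(t), keys[:i + 1], values[:i + 1], stats)
                  for i, t in enumerate(tokens)]
        tokens.append(pick_token(hidden[-1]))
    return tokens, stats["scores"]

def generate_with_cache(prompt, n_new):
    stats = {"scores": 0}
    tokens = list(prompt)
    keys, values = [], []   # the cache: one key and one value per seen token
    pending = list(prompt)  # tokens whose key/value are not cached yet
    for _ in range(n_new):
        for t in pending:
            k, v = key_value(t)
            keys.append(k)
            values.append(v)
            hidden = attend(embed(t), keys, values, stats)
        next_token = pick_token(hidden)
        tokens.append(next_token)
        pending = [next_token]  # only the newest token needs processing
    return tokens, stats["scores"]

plain_tokens, plain_cost = generate_no_cache([3, 1, 4], 5)
cached_tokens, cached_cost = generate_with_cache([3, 1, 4], 5)
print(plain_tokens == cached_tokens, plain_cost, cached_cost)
```

With M = 3 prompt tokens and N = 5 generated tokens, both runs produce the same sequence, but the uncached run computes 80 query-key scores against 28 for the cached one, and the gap widens quickly as the sequence grows. In `transformers` itself, `use_cache=True` plays the role of the `keys`/`values` lists here, with the cached tensors passed back to the model as `past_key_values`.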

Hope that helps, let me know if you have any further questions!
