What is the purpose of 'use_cache' in decoder?

I am using BartForConditionalGeneration for text summarization.
The ‘use_cache’ option is True by default when pre-training BART.

Purpose: The purpose of the ‘use_cache’ option seems to be speeding up decoding.
(https://github.com/huggingface/transformers/blob/d822ab636b6a14ed50f7bca0797c1de42c19de61/src/transformers/modeling_bart.py#L120-L122)

Implementation: When ‘use_cache’ = True, the decoder uses only the last time step of input_ids and the positional embedding.
(https://github.com/huggingface/transformers/blob/d822ab636b6a14ed50f7bca0797c1de42c19de61/src/transformers/modeling_bart.py#L551-L553)
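
To illustrate what I mean, here is roughly the idea I read into those lines (just a sketch of my understanding with a hypothetical helper, not the actual library code):

```python
# Sketch of my understanding of the linked lines (hypothetical helper, not library code):
# once cached states from earlier steps exist, only the newest decoder token
# needs a fresh embedding / hidden state.
def prepare_decoder_input_ids(decoder_input_ids, past_key_values=None):
    if past_key_values is not None:
        # keep only the last time step; earlier positions are already in the cache
        decoder_input_ids = decoder_input_ids[:, -1:]
    return decoder_input_ids
```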

Can anyone explain how the above purpose and implementation are related?

Thank you


Hi @lifelongeek!

The cache is only used for generation, not for training.

Say you have M input tokens and want to generate N output tokens.

Without cache, the model computes the M hidden states for the input, then generates the first output token. Then it computes the hidden state for that first generated token and generates a second one. Then it recomputes the hidden states for the first two generated tokens to produce the third one, and so on and so forth.

However, since the output side is auto-regressive, an output token’s hidden state, once computed, stays the same for every further generation step, so recomputing it every time we want to generate a new token seems wasteful.

With the cache, the model saves each hidden state once it has been computed, and at each time step only computes the hidden state for the most recently generated output token, re-using the saved ones for all previous tokens. This reduces the generation complexity from O(n^3) to O(n^2) for a transformer model.
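
To make that concrete, here is a minimal greedy-decoding sketch that feeds the cache back in by hand (model.generate does all of this for you; the field names below match recent transformers versions, where the cache is exposed as past_key_values, so older releases may differ):

```python
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").eval()

inputs = tok("Some long article to summarize ...", return_tensors="pt")
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
past_key_values = None

with torch.no_grad():
    # The M input tokens are encoded exactly once.
    encoder_outputs = model.get_encoder()(**inputs)
    for _ in range(20):
        out = model(
            encoder_outputs=encoder_outputs,
            attention_mask=inputs["attention_mask"],
            decoder_input_ids=decoder_input_ids,  # only the newest token once the cache exists
            past_key_values=past_key_values,
            use_cache=True,
        )
        # Cache of the key/value states of all previously generated tokens.
        past_key_values = out.past_key_values
        next_token = out.logits[:, -1, :].argmax(-1, keepdim=True)
        # Feed only the freshly generated token into the next step.
        decoder_input_ids = next_token
```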

Hope that helps, let me know if you have any further questions!


Thank you for the clarification :grin:

Would this cache also be used if I call the generate method multiple times with the same conditional text as input?
I’d like to see the intermediate results of the prediction, but I don’t want to recompute the hidden states more often than necessary.
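
What I had in mind is something like pre-computing the encoder pass once and handing it to each generate call (just a sketch; I’m assuming here that generate accepts a precomputed encoder_outputs, which seems to be the case in recent versions):

```python
import torch
from transformers import BartForConditionalGeneration, BartTokenizer

tok = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn").eval()

inputs = tok("Shared conditional text ...", return_tensors="pt")
with torch.no_grad():
    encoder_outputs = model.get_encoder()(**inputs)  # computed only once

# Inspect intermediate results by generating to different lengths;
# since encoder_outputs is passed in, the encoder should not be re-run.
for max_len in (20, 40, 60):
    summary_ids = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        encoder_outputs=encoder_outputs,
        max_length=max_len,
        num_beams=1,
    )
    print(tok.decode(summary_ids[0], skip_special_tokens=True))
```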


I have the same use case as @yjernite. I’ll need to call generate multiple times with some shared input_ids, and I’d like to cache those computations.
I’m looking into how to implement this. It would be great if Hugging Face supported it.


Why is it that use_cache isn’t compatible with gradient checkpointing? use_cache is just for generation, and there are no gradients during generation. @ybelkada maybe, @muellerzr :pray:
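
In case it’s useful, the workaround I’m applying for now is just to switch the cache off while training with checkpointing and switch it back on for inference (a sketch, assuming a transformers version that has gradient_checkpointing_enable()):

```python
from transformers import BartForConditionalGeneration

model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")

# Training: activation checkpointing on, cache off (the cache is only useful for generation).
model.gradient_checkpointing_enable()
model.config.use_cache = False

# ... run the training loop here ...

# Inference: restore the default so the cache is used again during generate().
model.config.use_cache = True
```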