Storing and loading KV cache

Hello,

is there an explicit way to store and later load the KV cache in the models?

Thanks!

Hey!

You can reuse a cache object in the next generation steps as follows:

out = model.generate(input_ids, use_cache=True, return_dict_in_generate=True)
past_key_values = out.past_key_values
generated_ids = out.sequences

# Now we can continue generation using cache and already generated tokens
out_continued = model.generate(generated_ids, past_key_values=past_key_values, return_dict_in_generate=True)
continued_generated_ids = out_continued.sequences

If you want to save the cache and load it back, we don’t have an explicit API for that yet. But you can save the keys and values yourself, where each of them is a list of tensors (one per layer):

import torch

keys, values = past_key_values.key_cache, past_key_values.value_cache
torch.save(keys, "keys.pt")
torch.save(values, "values.pt")

# Later you can load it back as follows, assuming you used the default DynamicCache
from transformers import DynamicCache

past_key_values = DynamicCache()
past_key_values.key_cache = torch.load("keys.pt")
past_key_values.value_cache = torch.load("values.pt")

Btw, can you share the use case where saving and loading the cache is needed? We are now trying to make a unified API for all cache objects, and it would help us understand common use cases.


Thanks for the answer @RaushanTurganbay!
I will start with that and try it out.

Regarding use case -
The use case is that we have many long-context examples that we query repeatedly, so it doesn’t make sense to recompute the whole context for every query. It makes more sense to compute the KV values once and reload them for each query.
Does that make sense?
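
Concretely, here is a rough sketch of what I have in mind, based on the snippets above (the model name, file names, and the long_context / query strings are just placeholders, and I’m assuming the default DynamicCache):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache

model_name = "meta-llama/Llama-2-7b-hf"  # placeholder, any causal LM
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

long_context = "..."  # the long document we query over and over (placeholder)
query = "..."         # one of many questions about that document (placeholder)

# 1) Pre-fill the cache for the long context once and save it to disk
context_ids = tokenizer(long_context, return_tensors="pt").input_ids
with torch.no_grad():
    ctx_cache = model(context_ids, past_key_values=DynamicCache(), use_cache=True).past_key_values
torch.save(ctx_cache.key_cache, "ctx_keys.pt")
torch.save(ctx_cache.value_cache, "ctx_values.pt")

# 2) For every query, load a fresh copy of the context cache (generation
#    extends the cache in place, so the same object shouldn't be reused)
cache = DynamicCache()
cache.key_cache = torch.load("ctx_keys.pt")
cache.value_cache = torch.load("ctx_values.pt")
query_ids = tokenizer(long_context + query, return_tensors="pt").input_ids
out = model.generate(query_ids, past_key_values=cache,
                     max_new_tokens=50, return_dict_in_generate=True)
print(tokenizer.decode(out.sequences[0, query_ids.shape[1]:], skip_special_tokens=True))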

I see, thanks for explaining. We are making a cache that will inherit from torch.nn.Module here, so after the PR is merged you should be able to copy the cache or save it with torch.save.
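
For example, something like this sketch should work once the PR lands (the file name is arbitrary):

import torch

# Once the cache object is an nn.Module, saving/loading should reduce to
# plain torch.save / torch.load of the whole cache
torch.save(past_key_values, "past_key_values.pt")
past_key_values = torch.load("past_key_values.pt")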

That’s great! Any estimate of when this PR will be merged?