past_key_values with multiple new tokens

I have a GPT-style model that I’m using to generate text for a set of prompts. All of the prompts share a common prefix, followed by a variable suffix. For example, something like:

Prompt 1: “A B C X Y”
Prompt 2: “A B C W V”
Prompt 3: “A B C Q R S”

A simple way of processing Prompt 1 using key-value caching is like this:

abcxy_tokens = tokenize("A B C X Y")
result = model(abcxy_tokens, use_cache=True)
past_key_values = result['past_key_values']
# greedy-pick the next token from the logits at the last position
predicted_token = argmax(result['logits'][:, -1, :])
# feed only the new token; the cache covers everything before it
result = model(predicted_token, use_cache=True, past_key_values=past_key_values)
... (repeat, updating past_key_values from each result)
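
For reference, here’s a fully runnable version of that loop (GPT-2 and the 10-step greedy loop are just stand-ins for illustration; my real model class may differ):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

input_ids = tok("A B C X Y", return_tensors="pt").input_ids
with torch.no_grad():
    out = model(input_ids, use_cache=True)
past_key_values = out.past_key_values

for _ in range(10):  # greedy decoding, 10 new tokens
    next_id = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    with torch.no_grad():
        out = model(next_id, use_cache=True, past_key_values=past_key_values)
    past_key_values = out.past_key_values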

This works great for me, but I would like to save more compute by caching the key-values for the shared prefix, “A B C”. I want to do something like this:

abc_tokens = tokenize("A B C")
prefix_key_values = model(abc_tokens, use_cache=True)['past_key_values']
# prefix_key_values can now be re-used for prompt 1, prompt 2, etc.

suffix_tokens = tokenize("X Y")
result = model(suffix_tokens, use_cache=True, past_key_values=prefix_key_values)
...

However, when I try this, I get an error complaining about a shape mismatch. The code only works if there is a single new token.

I don’t see any fundamental reason why key-value caching shouldn’t work with multiple new tokens. Does the existing implementation silently assume a single new token? Is there an alternate method I should be using for this scenario?

This should be possible; in fact, I’ve done something very similar before. Could you share the error message, the stack trace, and which Hugging Face model class you’re using?
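
In the meantime, here’s roughly the pattern that has worked for me, sketched with GPT-2 (the checkpoint and the strings are illustrative). The main thing to watch is that if you pass an attention_mask alongside past_key_values, it has to span the cached prefix plus the new tokens, not just the new tokens:

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tok = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

# Run the shared prefix once and keep its cache.
prefix_ids = tok("A B C", return_tensors="pt").input_ids
with torch.no_grad():
    prefix_key_values = model(prefix_ids, use_cache=True).past_key_values

# Feed a multi-token suffix on top of the cached prefix.
suffix_ids = tok(" X Y", return_tensors="pt").input_ids

# The mask must cover cached positions + new positions.
attention_mask = torch.ones(1, prefix_ids.size(1) + suffix_ids.size(1), dtype=torch.long)
with torch.no_grad():
    result = model(suffix_ids,
                   past_key_values=prefix_key_values,
                   attention_mask=attention_mask,
                   use_cache=True)
next_token = result.logits[:, -1, :].argmax(dim=-1, keepdim=True)

One caveat: depending on your transformers version, past_key_values may come back as a tuple of tensors or as a Cache object; if it’s a cache that gets updated in place, you’ll want to copy it before reusing the same prefix across prompts.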