past_key_values with multiple new tokens

I have a GPT-style model that I’m using to generate text for a set of prompts. Each prompt has a shared prefix followed by a variable suffix, for example:

Prompt 1: “A B C X Y”
Prompt 2: “A B C W V”
Prompt 3: “A B C Q R S”

A simple way of processing Prompt 1 using key-value caching is like this:

abcxy_tokens = tokenize("A B C X Y")
result = model(abcxy_tokens,use_cache=True)
past_key_values = result['past_key_values']
predicted_token = argmax(result['logits'][:, -1, :])  # next token comes from the last position only
result = model(predicted_token,use_cache=True,past_key_values=past_key_values)
...(repeat)
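For reference, here is a runnable version of the single-prompt loop above, as a minimal sketch rather than the poster's exact code. It assumes the Hugging Face transformers `GPT2LMHeadModel` (the question does not say which model class is used), builds a tiny randomly initialized GPT-2 from `GPT2Config` so nothing is downloaded, and uses hypothetical token ids in place of `tokenize()`. Only the logits at the last position feed the argmax, and `past_key_values` is taken from each step's output.

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Assumption: a tiny, randomly initialized GPT-2 stands in for the real model
# so the sketch runs offline; token ids below are hypothetical (no tokenizer).
torch.manual_seed(0)
config = GPT2Config(vocab_size=100, n_positions=64, n_embd=32, n_layer=2, n_head=2)
model = GPT2LMHeadModel(config).eval()  # eval() disables dropout


@torch.no_grad()
def greedy_with_cache(model, input_ids, max_new_tokens):
    out = model(input_ids, use_cache=True)  # whole prompt in one pass
    generated = input_ids
    for step in range(max_new_tokens):
        # Only the last position's logits predict the next token.
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=-1)
        if step + 1 < max_new_tokens:
            # Feed just the new token, reusing the cache from the previous output.
            out = model(next_token, past_key_values=out.past_key_values, use_cache=True)
    return generated


@torch.no_grad()
def greedy_without_cache(model, input_ids, max_new_tokens):
    # Reference: recompute attention over the full sequence at every step.
    generated = input_ids
    for _ in range(max_new_tokens):
        logits = model(generated, use_cache=False).logits
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=-1)
    return generated


prompt_ids = torch.tensor([[1, 2, 3, 24, 25]])  # hypothetical ids for "A B C X Y"
cached_ids = greedy_with_cache(model, prompt_ids, max_new_tokens=5)
reference_ids = greedy_without_cache(model, prompt_ids, max_new_tokens=5)
print(torch.equal(cached_ids, reference_ids))
```

Comparing against a cache-free loop is a cheap way to check that the cache is being threaded through correctly.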

This works great for me, but I would like to save more compute by caching the key-values for the shared prefix, “A B C”. I want to do something like this:

abc_tokens = tokenize("A B C")
prefix_key_values = model(abc_tokens,use_cache=True)['past_key_values']
#prefix_key_values can now be re-used for prompt 1, prompt 2, etc.

prompt_tokens = tokenize("X Y")
result = model(prompt_tokens,use_cache=True,past_key_values=prefix_key_values)
...
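A minimal sketch of prefix caching with several new tokens, under the same assumptions (tiny random `GPT2LMHeadModel` from transformers, hypothetical token ids). It computes the cache from the prefix ids only, passes a multi-token suffix in a single call, and checks the suffix logits against a full uncached pass. Two details are easy to get wrong: recent transformers versions return a cache object that is extended in place, so each prompt here gets its own `copy.deepcopy` of the prefix cache; and if an `attention_mask` is passed, it has to span prefix plus suffix. A mask sized for only the new tokens is one plausible (unconfirmed) cause of a shape mismatch.

```python
import copy

import torch
from transformers import GPT2Config, GPT2LMHeadModel

# Assumption: tiny random GPT-2 in place of the real model; ids are hypothetical.
torch.manual_seed(0)
config = GPT2Config(vocab_size=100, n_positions=64, n_embd=32, n_layer=2, n_head=2)
model = GPT2LMHeadModel(config).eval()

prefix_ids = torch.tensor([[1, 2, 3]])              # stands in for "A B C"
suffixes = [torch.tensor([[24, 25]]),               # "X Y"
            torch.tensor([[23, 22]]),               # "W V"
            torch.tensor([[17, 18, 19]])]           # "Q R S"

max_diffs = []
with torch.no_grad():
    # Compute the shared prefix cache once.
    prefix_cache = model(prefix_ids, use_cache=True).past_key_values

    for suffix_ids in suffixes:
        # Copy so one prompt's suffix never leaks into the next prompt's cache.
        cache = copy.deepcopy(prefix_cache)
        # The mask covers cached prefix positions plus the new suffix positions.
        total_len = prefix_ids.shape[1] + suffix_ids.shape[1]
        attention_mask = torch.ones(1, total_len, dtype=torch.long)
        cached = model(suffix_ids, past_key_values=cache,
                       attention_mask=attention_mask, use_cache=True)

        # Reference: run the whole prompt without any cache.
        full = model(torch.cat([prefix_ids, suffix_ids], dim=-1), use_cache=False)
        suffix_logits = full.logits[:, prefix_ids.shape[1]:]
        max_diffs.append((cached.logits - suffix_logits).abs().max().item())

print(max_diffs)  # tiny floating-point differences are expected
```

Position ids are derived from the cache length inside the GPT-2 model, so the suffix tokens land at positions 3, 4, ... without passing `position_ids` explicitly.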

However, when I try this, I get an error complaining about a shape mismatch. The code only works if there is a single new token.

I don’t think that there’s any fundamental reason why key-value caching would not be possible with multiple new tokens. Does the existing implementation silently assume a single new token? Is there an alternate method I should be using for this scenario?
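The reasoning that nothing fundamental limits caching to a single new token can be checked on a toy example. This sketch uses a hand-written single-head attention in NumPy, not any library's implementation: keys and values for a 3-token prefix are computed once, then two new tokens are processed together, with the causal mask offset by the number of cached positions. The cached output matches the corresponding rows of a full pass.

```python
import numpy as np


def causal_attention(q, k, v, past_len=0):
    """Single-head scaled dot-product attention with a causal mask.

    q holds only the new positions; k and v hold past + new positions.
    New query i sits at absolute position past_len + i.
    """
    scores = q @ k.T / np.sqrt(q.shape[-1])
    n_new, n_total = scores.shape
    # Key j is visible to query i iff j <= past_len + i (rectangular mask).
    allowed = np.arange(n_total)[None, :] <= (past_len + np.arange(n_new))[:, None]
    scores = np.where(allowed, scores, -np.inf)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v


rng = np.random.default_rng(0)
d = 8
x = rng.normal(size=(5, d))  # toy embeddings for "A B C X Y"
wq, wk, wv = (rng.normal(size=(d, d)) for _ in range(3))

# Reference: all five tokens in one pass.
full_out = causal_attention(x @ wq, x @ wk, x @ wv)

# Cached: prefix keys/values are computed once...
k_cache, v_cache = x[:3] @ wk, x[:3] @ wv
# ...then both new tokens attend to prefix + new keys in a single call.
new = x[3:]
k_all = np.concatenate([k_cache, new @ wk])
v_all = np.concatenate([v_cache, new @ wv])
cached_out = causal_attention(new @ wq, k_all, v_all, past_len=3)

print(np.allclose(cached_out, full_out[3:]))  # True
```

The only change relative to the single-token case is that the mask becomes a rectangular `n_new × (past_len + n_new)` lower-triangular block instead of a single row.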

This should be possible, and in fact I’ve done something very similar before. Could you share the error message, stack trace, and which Hugging Face model class you’re using?