Problem Description
To feed a given sequence to a decoder-only model and obtain its past_key_values, I tried two approaches:
- Directly take the whole sequence as input_ids.
- First feed part of the sequence to obtain its past_key_values, then feed the resulting past_key_values together with the rest of the sequence to the model.
I expected the two approaches to be equivalent; however, the results seem to be different.
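For reference, here is a minimal sketch of the two approaches (prefix_ids and suffix_ids are placeholder tensors, not the names used in the full script below). The returned cache can be indexed as cache[layer] to get that layer's (keys, values), each of shape (batch, num_key_value_heads, seq_len, head_dim).
# Minimal sketch, assuming a causal LM `model` and two token-id tensors
# prefix_ids / suffix_ids (placeholder names).
import torch

# Approach 1: feed the whole sequence at once.
full_cache = model(input_ids=torch.cat([prefix_ids, suffix_ids], dim=1),
                   use_cache=True).past_key_values

# Approach 2: feed the prefix first, then reuse its cache for the suffix.
prefix_cache = model(input_ids=prefix_ids, use_cache=True).past_key_values
split_cache = model(input_ids=suffix_ids, past_key_values=prefix_cache,
                    use_cache=True).past_key_values

# Both caches now cover the full sequence; cache[layer][0] / cache[layer][1]
# are the cached keys / values of that layer.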
Requirements
transformers==4.48.0
torch==2.4.0+cu121
Code
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
# model_path = 'Qwen/Qwen2.5-7B-Instruct'
model_path = 'meta-llama/Llama-3.1-8B-Instruct'
prompt = "Instruction: Write a short story about a cat. The story should be 100 words long and include a happy ending."
response = "Once upon a time, there was a cat named Whiskers."
if __name__ == "__main__":
    model = AutoModelForCausalLM.from_pretrained(model_path).cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    prompt_ids = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], return_tensors='pt', add_generation_prompt=True).to(model.device)
    response_ids = tokenizer.encode(response, return_tensors='pt').to(model.device)
    prompt_cache = model(input_ids=prompt_ids, use_cache=True).past_key_values
    # forward with only input_ids
    cache1 = model(input_ids=torch.cat([prompt_ids, response_ids], dim=1), use_cache=True).past_key_values
    # forward with input_ids and past_key_values
    cache2 = model(input_ids=response_ids, past_key_values=prompt_cache, use_cache=True).past_key_values
    # compare the first layer's keys of cache1 and cache2; they are expected to be the same
    diff = (cache1[0][0] != cache2[0][0]).count_nonzero()
    print(f"{diff} out of {cache1[0][0].numel()} elements are different")
Output
For Qwen2.5-7B-Instruct:
4236 out of 35328 elements are different
For Llama-3.1-8B-Instruct:
13086 out of 75776 elements are different
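The comparison above counts elements that are not bitwise identical, so part of the difference could in principle be plain floating-point noise. A tolerance-based version of the same check would look like the sketch below (not part of the run above; the atol value is an arbitrary assumption):
# Hypothetical tolerance-based comparison, reusing cache1 / cache2 from the
# script above.
def count_large_diffs(a, b, atol=1e-3):
    # Count elements whose absolute difference exceeds the tolerance.
    return (a - b).abs().gt(atol).count_nonzero().item()

print(count_large_diffs(cache1[0][0], cache2[0][0]))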
Update
I added another example, based on the tutorial that shows how to re-use a prompt cache for generation.
import copy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, StaticCache
#model_id = 'Qwen/Qwen2.5-7B-Instruct'
model_id = 'meta-llama/Llama-3.1-8B-Instruct'
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Init StaticCache with big enough max-length (1024 tokens for the below example)
# You can also init a DynamicCache, if that suits you better
# prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16)
prompt_cache = DynamicCache()
INITIAL_PROMPT = "You are a helpful assistant."
inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to("cuda")
# This is the common cached prompt; run the forward pass without grad so the cache can be copied later
with torch.no_grad():
    prompt_cache = model(**inputs_initial_prompt, past_key_values=prompt_cache).past_key_values

prompts = ["Help me to write a blogpost about travelling.", "What is the capital of France?"]
for prompt in prompts:
    new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
    past_key_values = copy.deepcopy(prompt_cache)
    outputs = model.generate(**new_inputs, past_key_values=past_key_values, max_new_tokens=20, do_sample=False)
    response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    print(f"response with KVCache:\n{response}\n")
    outputs2 = model.generate(**new_inputs, max_new_tokens=20, do_sample=False)
    response2 = tokenizer.batch_decode(outputs2, skip_special_tokens=True)[0]
    print(f"response without KVCache:\n{response2}\n")
For Qwen2.5-7B-Instruct:
response with KVCache:
You are a helpful assistant.Help me to write a blogpost about travelling. What should I include? When writing a blog post about traveling, there are several key elements you can
response without KVCache:
You are a helpful assistant.Help me to write a blogpost about travelling. The title is "The Best Way to Travel: Solo or With Friends?"
Certainly! Here's a
response with KVCache:
You are a helpful assistant.What is the capital of France? No, Paris is the capital of France. The question is incomplete and does not provide an answer.
response without KVCache:
You are a helpful assistant.What is the capital of France? The capital of France is Paris.
For Llama-3.1-8B-Instruct:
response with KVCache:
You are a helpful assistant.Help me to write a blogpost about travelling. I will give you the topic and you will write the top 5 things to do in that place
response without KVCache:
You are a helpful assistant.Help me to write a blogpost about travelling. Here is the my first paragraph:
"Travelling is a great way to broaden your horizons
response with KVCache:
You are a helpful assistant.What is the capital of France? Paris.
The capital of France is Paris. is the capital of France? Paris. The capital of
response without KVCache:
You are a helpful assistant.What is the capital of France? Paris.
What is the capital of France? Paris.
What is the capital of Australia? Canberra.
I use greedy decoding, yet generating with the reused KV cache yields different outputs from generating without it. I think this is unexpected.
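To narrow down whether the divergence already appears at the first generated token, here is a hedged diagnostic sketch (not run above) that compares the next-token logits with and without the reused cache. It reuses model, tokenizer, prompt_cache, INITIAL_PROMPT and prompts from the script above, and mirrors how generate slices off the tokens already covered by the cache.
import copy
import torch

with torch.no_grad():
    new_inputs = tokenizer(INITIAL_PROMPT + prompts[0], return_tensors="pt").to("cuda")

    # Next-token logits without any precomputed cache.
    logits_full = model(**new_inputs).logits[:, -1, :]

    # Same position, but reusing the cached prefix: feed only the tokens after
    # the cached prefix, together with a copy of the cache. This assumes the
    # first prefix_len tokens of the new input match the cached prompt's
    # tokens, as the tutorial pattern implicitly does.
    cache = copy.deepcopy(prompt_cache)
    prefix_len = cache.get_seq_length()
    suffix_ids = new_inputs["input_ids"][:, prefix_len:]
    logits_cached = model(input_ids=suffix_ids, past_key_values=cache).logits[:, -1, :]

    print("max abs diff:", (logits_full - logits_cached).abs().max().item())
    print("same greedy token:", torch.equal(logits_full.argmax(dim=-1), logits_cached.argmax(dim=-1)))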