Outputs change if re-using KVCache (past_key_values) for model.forward and generation

Problem Description

To feed a given sequence to a decoder-only model and obtain its past_key_values, I tried two approaches:

  1. Directly take the whole sequence as input_ids

  2. First feed part of the sequence and obtain its past_key_values, then feed the past_key_values and the rest of the sequence to the model

I expected the two approaches to be equivalent; however, the results seem to be different.

Requirements

transformers==4.48.0
torch==2.4.0+cu121

Code

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# model_path = 'Qwen/Qwen2.5-7B-Instruct'
model_path = 'meta-llama/Llama-3.1-8B-Instruct'


prompt = "Instruction: Write a short story about a cat. The story should be 100 words long and include a happy ending."
response = "Once upon a time, there was a cat named Whiskers."


if __name__ == "__main__":
    model = AutoModelForCausalLM.from_pretrained(model_path).cuda()
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    prompt_ids = tokenizer.apply_chat_template([{"role": "user", "content": prompt}], return_tensors='pt', add_generation_prompt=True).to(model.device)
    response_ids = tokenizer.encode(response, return_tensors='pt').to(model.device)
    
    prompt_cache = model(input_ids=prompt_ids, use_cache=True).past_key_values
    # forward with only input_ids
    cache1 = model(input_ids=torch.cat([prompt_ids, response_ids], dim=1), use_cache=True).past_key_values
    # forward with input_ids and past_key_values
    cache2 = model(input_ids=response_ids, past_key_values=prompt_cache, use_cache=True).past_key_values

    # compare the first layer's keys of cache1 and cache2; they are expected to be the same
    diff = (cache1[0][0] != cache2[0][0]).count_nonzero()
    print(f"{diff} out of {cache1[0][0].numel()} elements are different")

Output

For Qwen2.5-7B-Instruct:

4236 out of 35328 elements are different

For Llama-3.1-8B-Instruct:

13086 out of 75776 elements are different
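
As a follow-up diagnostic (a sketch, not part of the original script, that reuses cache1 and cache2 from the code above), it helps to look at the magnitude of the mismatch rather than the count of exactly unequal elements:

# Sketch: measure how large the cache discrepancy is, reusing cache1/cache2 from above.
k1, k2 = cache1[0][0], cache2[0][0]   # first layer's keys from both runs
max_abs_diff = (k1 - k2).abs().max().item()
print(f"max abs difference: {max_abs_diff:.3e}")
print(f"allclose within atol=1e-3: {torch.allclose(k1, k2, atol=1e-3)}")

If the differences are tiny in absolute value, the mismatch is consistent with floating-point accumulation-order effects rather than a logic error.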

Update

I added another example, based on the tutorial that shows how to re-use the cache for model generation.

import copy
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, StaticCache

#model_id = 'Qwen/Qwen2.5-7B-Instruct'
model_id = 'meta-llama/Llama-3.1-8B-Instruct'
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Init StaticCache with big enough max-length (1024 tokens for the below example)
# You can also init a DynamicCache, if that suits you better
# prompt_cache = StaticCache(config=model.config, max_batch_size=1, max_cache_len=1024, device="cuda", dtype=torch.bfloat16)
prompt_cache = DynamicCache()

INITIAL_PROMPT = "You are a helpful assistant."
inputs_initial_prompt = tokenizer(INITIAL_PROMPT, return_tensors="pt").to("cuda")
# This is the common prompt being cached; we need to run the forward pass without grad to be able to copy it
with torch.no_grad():
    prompt_cache = model(**inputs_initial_prompt, past_key_values=prompt_cache).past_key_values

prompts = ["Help me to write a blogpost about travelling.", "What is the capital of France?"]
for prompt in prompts:
    new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors="pt").to("cuda")
    past_key_values = copy.deepcopy(prompt_cache)
    outputs = model.generate(**new_inputs, past_key_values=past_key_values, max_new_tokens=20, do_sample=False)
    response = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    print(f"response with KVCache:\n{response}\n")


    outputs2 = model.generate(**new_inputs, max_new_tokens=20, do_sample=False)
    response2 = tokenizer.batch_decode(outputs2, skip_special_tokens=True)[0]
    print(f"response without KVCache:\n{response2}\n")

For Qwen2.5-7B-Instruct:

response with KVCache:
You are a helpful assistant.Help me to write a blogpost about travelling. What should I include? When writing a blog post about traveling, there are several key elements you can

response without KVCache:
You are a helpful assistant.Help me to write a blogpost about travelling. The title is "The Best Way to Travel: Solo or With Friends?"
Certainly! Here's a

response with KVCache:
You are a helpful assistant.What is the capital of France? No, Paris is the capital of France. The question is incomplete and does not provide an answer.

response without KVCache:
You are a helpful assistant.What is the capital of France? The capital of France is Paris.

For Llama-3.1-8B-Instruct:

response with KVCache:
You are a helpful assistant.Help me to write a blogpost about travelling. I will give you the topic and you will write the top 5 things to do in that place

response without KVCache:
You are a helpful assistant.Help me to write a blogpost about travelling. Here is the my first paragraph:
"Travelling is a great way to broaden your horizons

response with KVCache:
You are a helpful assistant.What is the capital of France? Paris.
The capital of France is Paris. is the capital of France? Paris. The capital of

response without KVCache:
You are a helpful assistant.What is the capital of France? Paris.
What is the capital of France? Paris.
What is the capital of Australia? Canberra.

I use greedy decoding, and it seems that generating with KVCache yields different outputs. I think this is unexpected.
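
One way to narrow this down (a sketch, not part of the original code, reusing model, inputs_initial_prompt, prompt_cache, and the last new_inputs from the snippet above) is to compare the logits of the very first generated token with and without the pre-filled cache:

import copy
import torch

# Sketch: compare next-token logits for the last prompt, with and without the
# pre-filled prompt_cache. Reuses variables from the snippet above.
cached_len = inputs_initial_prompt["input_ids"].shape[1]

# Sanity check: the cached prompt tokens must be an exact prefix of the
# concatenated input, otherwise the two paths are not comparable at all.
assert torch.equal(new_inputs["input_ids"][:, :cached_len],
                   inputs_initial_prompt["input_ids"])

with torch.no_grad():
    # Uncached path: the full prompt in one forward pass.
    logits_full = model(**new_inputs).logits[:, -1, :]
    # Cached path: feed only the tokens that are not already in the cache.
    remainder = new_inputs["input_ids"][:, cached_len:]
    logits_cached = model(input_ids=remainder,
                          past_key_values=copy.deepcopy(prompt_cache)).logits[:, -1, :]

max_diff = (logits_full - logits_cached).abs().max().item()
print(f"max abs logit difference at the first generation step: {max_diff:.3e}")
print(f"same greedy token: {logits_full.argmax(-1).item() == logits_cached.argmax(-1).item()}")

If the logits differ only by tiny amounts but the greedy argmax still flips at some step, the two generations will diverge from that point onward even though neither path is "wrong".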


I tried the code and got a warning, but I wonder if it is related to the change in outputs.

I am using the official cache implementation. Also, I added another example based on the code in this blog; the results are still unexpected.


Could it be this…?
But this shouldn’t be related to Llama.

I have found a post that could explain this: Possible Bug with KV Caching in Llama (original) model · Issue #25420 · huggingface/transformers · GitHub

In short: using the KV cache changes the logits slightly, especially when the model is loaded in 16-bit precision.
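
If precision really is the culprit, the discrepancy should be tiny in absolute terms and smaller at higher precision. Here is one way to check (a sketch, independent of the scripts above; the prompt strings and the dtype loop are illustrative, and float32 weights for an 8B model need roughly 32 GB of GPU memory, so a smaller checkpoint may be needed):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Sketch: repeat the cache comparison at two precisions and report the largest
# absolute difference instead of the number of exactly unequal elements.
model_path = 'meta-llama/Llama-3.1-8B-Instruct'
tokenizer = AutoTokenizer.from_pretrained(model_path)

prompt_ids = tokenizer.encode("Write a short story about a cat.", return_tensors='pt')
response_ids = tokenizer.encode("Once upon a time, there was a cat named Whiskers.",
                                add_special_tokens=False, return_tensors='pt')

for dtype in (torch.float32, torch.bfloat16):
    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=dtype).cuda()
    p_ids = prompt_ids.to(model.device)
    r_ids = response_ids.to(model.device)
    with torch.no_grad():
        prompt_cache = model(input_ids=p_ids, use_cache=True).past_key_values
        cache1 = model(input_ids=torch.cat([p_ids, r_ids], dim=1), use_cache=True).past_key_values
        cache2 = model(input_ids=r_ids, past_key_values=prompt_cache, use_cache=True).past_key_values
    max_diff = (cache1[0][0] - cache2[0][0]).abs().max().item()
    print(f"{dtype}: max abs key difference in layer 0 = {max_diff:.3e}")
    del model, prompt_cache, cache1, cache2
    torch.cuda.empty_cache()

The element counts reported earlier look alarming, but what matters for greedy decoding is whether the differences are ever large enough to flip an argmax.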

