KV caching for varying length texts

I am trying to do some structured text extraction using some kv caching tricks. For this example I will use the following model and data:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B-Instruct"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# data
text = """We introduce Mistral 7B, a 7–billion-parameter language model engineered for
superior performance and efficiency. Mistral 7B outperforms the best open 13B
model (Llama 2) across all evaluated benchmarks, and the best released 34B
model (Llama 1) in reasoning, mathematics, and code generation. Our model
leverages grouped-query attention (GQA) for faster inference, coupled with sliding
window attention (SWA) to effectively handle sequences of arbitrary length with a
reduced inference cost. We also provide a model fine-tuned to follow instructions,
Mistral 7B – Instruct, that surpasses Llama 2 13B – chat model both on human and
automated benchmarks. Our models are released under the Apache 2.0 license.
Code: <https://github.com/mistralai/mistral-src>
Webpage: <https://mistral.ai/news/announcing-mistral-7b/>"""

template = """{
    "Model": {
        "Name": "",
        "Number of parameters": "",
        "Number of max token": "",
        "Architecture": []
    },
    "Usage": {
        "Use case": [],
        "Licence": ""
    }
}"""

Without kv caching it works in the following way and gives reasonable results (though not amazing, given the model size):

def get_text_with_chat_template(text, key):
    prompt = f"### Text:\n{text}\n### Required Key:\n{key}"
    messages = [
        {
            "role": "system", "content": "You are a helpful assistant that can extract values given a requested key. \
            Be concise and precise. \
            Don't repeat the key in the answer."
        },
        {"role": "user", "content": prompt}
    ]
    text_with_chat_template = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    return text_with_chat_template

@torch.inference_mode()
def get_key_value(text_with_chat_template):
    batch_encodings = tokenizer([text_with_chat_template], return_tensors="pt", truncation=True, padding=True, max_length=1000).to(model.device)
    pred_ids = model.generate(**batch_encodings, max_new_tokens=200)
    output = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    # keep only the assistant's reply (the text after the "assistant" role header)
    return output[-1].split("assistant\n")[1].strip()

keys = ["Model Name", "Number of parameters", "Number of max token", "Architecture", "Use case", "Licence"]
for key in keys:
    text_with_chat_template = get_text_with_chat_template(text, key)
    print(key, ": ", get_key_value(text_with_chat_template))

I can make this slightly faster than the above by caching the kv values for the shared part of the prompt (the text plus everything up to the required key):

# The prefix up to and including "Required Key:\n" is identical for every key,
# so its kv cache can be computed once (text_with_chat_template is any of the
# prompts built above; the prefix is the same for all of them)
root_prompt = text_with_chat_template.split("Required Key:\n")[0] + "Required Key:\n"
root_inputs = tokenizer(text=root_prompt, padding="longest", return_tensors="pt").to(model.device)
with torch.inference_mode():
    kv_cache = model(**root_inputs, return_dict=True).past_key_values

prompt_end = "<|im_end|>\n<|im_start|>assistant\n"

with torch.inference_mode():
    for key in keys:
        batch_encodings = tokenizer(text=key + prompt_end, padding=True, truncation=True, return_tensors="pt").to(model.device)
        batch_encodings["input_ids"] = torch.cat([root_inputs["input_ids"], batch_encodings["input_ids"]], dim=-1)
        batch_encodings["attention_mask"] = torch.cat([root_inputs["attention_mask"], batch_encodings["attention_mask"]], dim=-1)
        pred_ids = model.generate(**batch_encodings, past_key_values=kv_cache, max_new_tokens=200)
        output = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        print(key, ": ", output[0].split("assistant\n")[1].strip())

# Model Name :  Mistral 7B
# Number of parameters :  7
# Number of max token :  7
# Architecture :  Language Model Architecture
# Use case :  Superior performance and efficiency
# Licence :  Apache 2.0

which gives similar results to the naive for loop.
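
To put a number on "slightly faster", a minimal timing sketch (my own addition; run_cached would just wrap the cached loop above in a function, and the numbers depend on hardware) would be:

import time

def time_runs(fn, n=3):
    # Average wall-clock seconds over n runs
    start = time.perf_counter()
    for _ in range(n):
        fn()
    return (time.perf_counter() - start) / n

def run_naive():
    for key in keys:
        get_key_value(get_text_with_chat_template(text, key))

print("naive loop:", time_runs(run_naive), "s per pass")
# wrap the cached loop in a run_cached() function and time it the same way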

Question

The issue comes when I try to batch the above. It may be due to how everything is padded on the right, or to the kv caching itself, but I am unsure how to fix it. This is what I have tried:

# Broadcast the single-sequence prompt cache across the batch dimension
expanded_kv_cache = tuple(
    (
        k.expand(len(keys), -1, -1, -1),
        v.expand(len(keys), -1, -1, -1)
    )
    for k, v in kv_cache
)

batch_encodings = tokenizer(
    text= [key + prompt_end for key in keys], 
    padding=True, 
    truncation=True, 
    return_tensors="pt"
).to(model.device)
batch_encodings["input_ids"] = torch.cat([root_inputs["input_ids"].expand(len(keys), -1), batch_encodings["input_ids"]], dim=-1)
batch_encodings["attention_mask"] = torch.cat([root_inputs["attention_mask"].expand(len(keys), -1), batch_encodings["attention_mask"]], dim=-1)

pred_ids = model.generate(**batch_encodings, past_key_values=expanded_kv_cache, max_new_tokens=20)
outputs = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
for output, key in zip(outputs, keys):
    print(key, " : ", output.split("assistant\n")[1].strip())

# Model Name  :  Human
# Number of parameters  :  Human
# Number of max token  :  7
# Architecture  :  Human Language Model
# Use case  :  Human and automated
# Licence  :  Human & Automated

I tried shifting the padding to the left, with even worse results:

def shift_zeros_left(input_ids, attention_mask):
    # Sort each row of the attention mask in ascending order so the zeros
    # (padding) move to the front, i.e. the padding ends up on the left
    sorted_indices = attention_mask.argsort(dim=1, descending=False)
    # Reorder both input_ids and attention_mask with the same indices
    shifted_input_ids = torch.gather(input_ids, 1, sorted_indices)
    shifted_attention_mask = torch.gather(attention_mask, 1, sorted_indices)

    return shifted_input_ids, shifted_attention_mask

shifted_input_ids, shifted_attention_mask = shift_zeros_left(
    batch_encodings['input_ids'], 
    batch_encodings['attention_mask']
)
pred_ids = model.generate(input_ids=shifted_input_ids, attention_mask=shifted_attention_mask, past_key_values=expanded_kv_cache, max_new_tokens=20)
outputs = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
for output, key in zip(outputs, keys):
    print(key, " : ", output.split("assistant\n")[1].strip())

# Model Name  :  concise precise precise precise
# Number of parameters  :  concise precise precision
# Number of max token  :  7
# Architecture  :  Conunductiveness Given a requested key.             Be concise and precise.
# Use case  :  Constance is concise.
# Licence  :  concise precise precise
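
My current suspicion (an assumption on my part, not something I have verified) is that this is a positions problem. generate derives position ids from the cumulative attention mask, roughly like this simplified sketch:

# Simplified sketch of how causal LMs in transformers turn the attention
# mask into position ids (not the exact library code)
position_ids = batch_encodings["attention_mask"].long().cumsum(-1) - 1
position_ids.masked_fill_(batch_encodings["attention_mask"] == 0, 1)

With right padding, the shorter rows end in pad tokens, so the first new token is predicted from a pad position; after the left shift, the cached prefix no longer sits at the positions stored in the kv cache.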

TIA


For anyone who encounters this question, see cache wrong code · Issue #34232 · huggingface/transformers · GitHub

TL;DR: batching is not possible for now due to padding issues; it will be supported in later releases.
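
In the meantime, a sequential workaround is to reuse a single prompt cache per key instead of batching. A minimal sketch, assuming a transformers version where DynamicCache is available (the cache API has changed across releases):

import copy
import torch
from transformers.cache_utils import DynamicCache

# Build the prompt cache once from the shared prefix
prompt_cache = DynamicCache()
with torch.inference_mode():
    prompt_cache = model(**root_inputs, past_key_values=prompt_cache).past_key_values

with torch.inference_mode():
    for key in keys:
        suffix = tokenizer(key + prompt_end, return_tensors="pt").to(model.device)
        input_ids = torch.cat([root_inputs["input_ids"], suffix["input_ids"]], dim=-1)
        attention_mask = torch.cat([root_inputs["attention_mask"], suffix["attention_mask"]], dim=-1)
        # generate mutates the cache in place, so give each key its own copy
        pred_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=copy.deepcopy(prompt_cache),
            max_new_tokens=200,
        )
        answer = tokenizer.decode(pred_ids[0], skip_special_tokens=True)
        print(key, ":", answer.split("assistant\n")[1].strip())

This avoids padding entirely (every sequence is processed on its own), so it sidesteps the issue above while still only encoding the long document once.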
