I am trying to do some structured text extraction using some KV-caching tricks. For this example I will use the following model and data:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-0.5B-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# data
text = """We introduce Mistral 7B, a 7–billion-parameter language model engineered for
superior performance and efficiency. Mistral 7B outperforms the best open 13B
model (Llama 2) across all evaluated benchmarks, and the best released 34B
model (Llama 1) in reasoning, mathematics, and code generation. Our model
leverages grouped-query attention (GQA) for faster inference, coupled with sliding
window attention (SWA) to effectively handle sequences of arbitrary length with a
reduced inference cost. We also provide a model fine-tuned to follow instructions,
Mistral 7B – Instruct, that surpasses Llama 2 13B – chat model both on human and
automated benchmarks. Our models are released under the Apache 2.0 license.
Code: <https://github.com/mistralai/mistral-src>
Webpage: <https://mistral.ai/news/announcing-mistral-7b/>"""
template = """{
    "Model": {
        "Name": "",
        "Number of parameters": "",
        "Number of max token": "",
        "Architecture": []
    },
    "Usage": {
        "Use case": [],
        "Licence": ""
    }
}"""
Without KV caching it works as follows and gives reasonable results (though not amazing, given the model size):
def get_text_with_chat_template(text, key):
    prompt = f"### Text:\n{text}\n### Required Key:\n{key}"
    messages = [
        {
            "role": "system",
            "content": "You are a helpful assistant that can extract values given a requested key. "
                       "Be concise and precise. "
                       "Don't repeat the key in the answer."
        },
        {"role": "user", "content": prompt}
    ]
    text_with_chat_template = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    return text_with_chat_template
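For reference, Qwen2.5-Instruct uses the ChatML format, so the rendered prompt looks roughly like this (abbreviated), which is also where the prompt_end string used further down comes from:

print(get_text_with_chat_template(text, "Model Name"))
# <|im_start|>system
# You are a helpful assistant that can extract values given a requested key. ...<|im_end|>
# <|im_start|>user
# ### Text:
# We introduce Mistral 7B, ...
# ### Required Key:
# Model Name<|im_end|>
# <|im_start|>assistant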
@torch.inference_mode()
def get_key_value(text_with_chat_template):
    batch_encodings = tokenizer([text_with_chat_template], return_tensors="pt", truncation=True, padding=True, max_length=1000).to(model.device)
    pred_ids = model.generate(**batch_encodings, max_new_tokens=200)
    output = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    return output[-1].split("assistant\n")[1].strip()
keys = ["Model Name", "Number of parameters", "Number of max token", "Architecture", "Use case", "Licence"]

for key in keys:
    text_with_chat_template = get_text_with_chat_template(text, key)
    print(key, ": ", get_key_value(text_with_chat_template))
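Purely to compare the variants, I time the baseline loop with a small helper (timed and run_naive are helpers I made up for this; rough wall-clock numbers, nothing rigorous):

import time

def timed(fn):
    # Rough wall-clock timing; on GPU a torch.cuda.synchronize() before and after
    # would give more accurate numbers.
    start = time.perf_counter()
    result = fn()
    return result, time.perf_counter() - start

def run_naive():
    return {key: get_key_value(get_text_with_chat_template(text, key)) for key in keys}

_, naive_seconds = timed(run_naive)
print(f"naive loop: {naive_seconds:.1f}s")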
I can make this slightly faster by caching the KV states for the shared prefix, i.e. the text plus everything in the prompt up to the required key:
root_prompt = text_with_chat_template.split("Required Key:\n")[0] + "Required Key:\n"
root_inputs = tokenizer(text=root_prompt, padding="longest", return_tensors="pt").to(model.device)

with torch.inference_mode():
    kv_cache = model(**root_inputs, return_dict=True).past_key_values
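The returned cache is iterable as one (key, value) pair per layer, each tensor of shape (batch, num_kv_heads, prefix_len, head_dim). On newer transformers versions it is a DynamicCache rather than a plain tuple, but indexing still works the same way, so a quick sanity check:

# One (key, value) pair per layer; shapes should be
# (batch=1, num_kv_heads, prefix_len, head_dim).
print("layers cached: ", len(kv_cache))
k0, v0 = kv_cache[0]
print("key shape:     ", tuple(k0.shape))
print("value shape:   ", tuple(v0.shape))
print("prefix tokens: ", root_inputs["input_ids"].shape[-1])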
prompt_end = "<|im_end|>\n<|im_start|>assistant\n"

with torch.inference_mode():
    for key in keys:
        batch_encodings = tokenizer(text=key + prompt_end, padding=True, truncation=True, return_tensors="pt").to(model.device)
        batch_encodings["input_ids"] = torch.cat([root_inputs["input_ids"], batch_encodings["input_ids"]], dim=-1)
        batch_encodings["attention_mask"] = torch.cat([root_inputs["attention_mask"], batch_encodings["attention_mask"]], dim=-1)
        pred_ids = model.generate(**batch_encodings, past_key_values=kv_cache, max_new_tokens=200)
        output = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        print(key, ": ", output[0].split("assistant\n")[1].strip())
# Model Name : Mistral 7B
# Number of parameters : 7
# Number of max token : 7
# Architecture : Language Model Architecture
# Use case : Superior performance and efficiency
# Licence : Apache 2.0
This gives results similar to the naive for loop.
Question
Now the issue comes when I try to batch the above. I suspect it is either how everything is right-padded or the KV caching itself, but I am unsure how to fix it. This is what I have tried:
expanded_kv_cache = tuple(
    (
        k.expand(len(keys), -1, -1, -1),
        v.expand(len(keys), -1, -1, -1)
    )
    for k, v in kv_cache
)
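As a sanity check (since .expand only creates broadcast views along the batch dimension, it does not copy memory), I verify that every cached tensor now reports a batch dimension of len(keys):

# .expand gives stride-0 broadcast views over the batch dimension (no copies);
# check that the reported shapes are what generate() will see.
for k, v in expanded_kv_cache:
    assert k.shape[0] == len(keys) and v.shape[0] == len(keys)
print("expanded key shape:", tuple(expanded_kv_cache[0][0].shape))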
batch_encodings = tokenizer(
    text=[key + prompt_end for key in keys],
    padding=True,
    truncation=True,
    return_tensors="pt"
).to(model.device)
batch_encodings["input_ids"] = torch.cat([root_inputs["input_ids"].expand(len(keys), -1), batch_encodings["input_ids"]], dim=-1)
batch_encodings["attention_mask"] = torch.cat([root_inputs["attention_mask"].expand(len(keys), -1), batch_encodings["attention_mask"]], dim=-1)
pred_ids = model.generate(**batch_encodings, past_key_values=expanded_kv_cache, max_new_tokens=20)
outputs = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)

for output, key in zip(outputs, keys):
    print(key, " : ", output.split("assistant\n")[1].strip())
# Model Name : Human
# Number of parameters : Human
# Number of max token : 7
# Architecture : Human Language Model
# Use case : Human and automated
# Licence : Human & Automated
I also tried shifting the padding to the left, with even worse results:
def shift_zeros_left(input_ids, attention_mask):
    # Get the sorted indices for each row in attention_mask (ascending, so zeros come first)
    sorted_indices = attention_mask.argsort(dim=1, descending=False)
    # Reorder both input_ids and attention_mask based on the sorted indices
    shifted_input_ids = torch.gather(input_ids, 1, sorted_indices)
    shifted_attention_mask = torch.gather(attention_mask, 1, sorted_indices)
    return shifted_input_ids, shifted_attention_mask
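On a toy batch this behaves as I expect, moving the pad positions to the front of each row (with the caveat that argsort is not guaranteed to be stable, so the relative order of the non-pad tokens is only preserved if the sort happens to keep equal mask values in order):

toy_ids = torch.tensor([[11, 12, 13, 0, 0],
                        [21, 22, 0, 0, 0]])
toy_mask = torch.tensor([[1, 1, 1, 0, 0],
                         [1, 1, 0, 0, 0]])
shifted_ids, shifted_mask = shift_zeros_left(toy_ids, toy_mask)
print(shifted_ids)
# expected (assuming equal mask values keep their order):
# tensor([[ 0,  0, 11, 12, 13],
#         [ 0,  0,  0, 21, 22]])
print(shifted_mask)
# tensor([[0, 0, 1, 1, 1],
#         [0, 0, 0, 1, 1]])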
shifted_input_ids, shifted_attention_mask = shift_zeros_left(
    batch_encodings['input_ids'],
    batch_encodings['attention_mask']
)
pred_ids = model.generate(input_ids=shifted_input_ids, attention_mask=shifted_attention_mask, past_key_values=expanded_kv_cache, max_new_tokens=20)
outputs = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)

for output, key in zip(outputs, keys):
    print(key, " : ", output.split("assistant\n")[1].strip())
# Model Name : concise precise precise precise
# Number of parameters : concise precise precision
# Number of max token : 7
# Architecture : Conunductiveness Given a requested key. Be concise and precise.
# Use case : Constance is concise.
# Licence : concise precise precise
TIA