Efficient batch inference using stacked past_key_values for multiple continuation candidates
Hi all,
I’m working on a task where, for each position in a sequence, I want to evaluate multiple possible token continuations. These continuations share the same prefix up to that point.
To speed things up, I compute the past_key_values incrementally for each position, and then, for all token candidates at that position, I reuse the same cache. I collect all candidate tokens into a batch and run a single forward pass using their individual input_ids and a stacked version of their corresponding past_key_values.
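To make the "reuse the same cache" part concrete: the core trick is just broadcasting one cache across a candidate batch. A minimal sketch, assuming the legacy tuple cache layout ((key, value) per layer) and a prefix computed with batch size 1 (broadcast_past and candidate_ids are illustrative names, not from my actual code):

import torch

def broadcast_past(past, n_candidates):
    # Each key/value has shape (1, n_heads, seq_len, head_dim); expanding the batch
    # dimension lets every candidate token attend to the same cached prefix.
    return tuple(
        (k.expand(n_candidates, -1, -1, -1), v.expand(n_candidates, -1, -1, -1))
        for k, v in past
    )

# candidate_ids has shape (n_candidates, 1): one next-token candidate per row.
# out = model(input_ids=candidate_ids,
#             past_key_values=broadcast_past(past, candidate_ids.size(0)),
#             use_cache=False)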
Here’s a simplified description of the approach:
For each position t in the sequence:
- Compute past_key_values up to that point.
- For each candidate token:
  - Create a one-token input_ids with that candidate.
  - Store the candidate's input_ids and the corresponding past_key_values.

Then:
- Stack all candidate inputs and their past_key_values into a batch.
- Run: model(input_ids=batch_input_ids, past_key_values=stacked_past, use_cache=False)
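The "compute past_key_values up to that point" step is just an incremental prefix pass; a simplified sketch of how I build the per-position caches (build_position_caches is an illustrative helper name):

import torch

def build_position_caches(model, prefix_ids, device):
    # Feed the prefix one token at a time so that after step t the cache covers tokens 0..t.
    past = None
    caches = []
    with torch.inference_mode():
        for tok in prefix_ids:  # prefix_ids: 1-D tensor of token ids
            step_ids = tok.view(1, 1).to(device)
            out = model(input_ids=step_ids, past_key_values=past, use_cache=True)
            past = out.past_key_values
            # Depending on the transformers version, `past` may be a DynamicCache that is
            # updated in place, so appending it here may not give an independent snapshot
            # per position; past.to_legacy_cache() (where available) is one way to copy it.
            caches.append(past)
    return caches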
My question is: Is this approach aligned with how past_key_values are expected to be used in batch inference? Are there potential pitfalls I should be aware of when batching multiple instances that have different past_key_values but share the same position in the context?
Any references to examples using similar logic would be appreciated!
Thanks in advance
import torch


def predict_each_candidate(cur_topk_idx, cur_ids, model, lm_head):
    # Evaluate multiple candidate tokens for each position in cur_ids, using
    # stacked past_key_values and a single batched forward pass.
    # `model` is the base (headless) model and `lm_head` its output projection;
    # `device` and `list_ids` (the set of allowed candidate token ids) are assumed
    # to be defined in the enclosing scope.
    past = None
    candidate_batches = []  # one-token input_ids per candidate
    candidate_pasts = []    # the cache each candidate should attend to
    candidate_map = []      # (position, option index) bookkeeping

    for t, (gt_tok, pos_tok_list) in enumerate(zip(cur_ids, cur_topk_idx)):
        # Extend the cache by one ground-truth token, so `past` now covers tokens 0..t.
        gt_tok_tensor = gt_tok.view(1, 1).to(device)
        with torch.inference_mode():
            out = model(input_ids=gt_tok_tensor, past_key_values=past, use_cache=True)
        past = out.past_key_values

        for option_i, tok in enumerate(pos_tok_list):
            if tok.item() in list_ids:
                input_ids = torch.tensor([[tok.item()]], device=device)
                candidate_batches.append(input_ids)
                candidate_pasts.append(past)
                candidate_map.append((t, option_i))

    if not candidate_batches:
        return None, candidate_map

    batch_input_ids = torch.cat(candidate_batches, dim=0)  # (n_candidates, 1)

    def stack_past_key_values(past_list):
        # Merge the per-candidate caches along the batch dimension. Each layer entry
        # is (key, value) of shape (1, n_heads, seq_len, head_dim), so concatenating
        # along dim=0 yields (n_candidates, n_heads, seq_len, head_dim). This assumes
        # every cache in past_list has the same seq_len; newer transformers versions
        # may also require wrapping the result with DynamicCache.from_legacy_cache.
        num_layers = len(past_list[0])
        stacked = []
        for layer_idx in range(num_layers):
            keys = torch.cat([p[layer_idx][0] for p in past_list], dim=0)
            values = torch.cat([p[layer_idx][1] for p in past_list], dim=0)
            stacked.append((keys, values))
        return tuple(stacked)

    stacked_past = stack_past_key_values(candidate_pasts)
    with torch.inference_mode():
        out = model(input_ids=batch_input_ids, past_key_values=stacked_past, use_cache=False)
    logits = lm_head(out.last_hidden_state)  # (n_candidates, 1, vocab_size)
    return logits, candidate_map
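For completeness, this is roughly how the function gets called; GPT-2 and the dummy cur_topk_idx / list_ids below are placeholders just to make the sketch self-contained:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
lm = AutoModelForCausalLM.from_pretrained("gpt2").to(device).eval()
base_model, lm_head = lm.transformer, lm.lm_head  # GPT-2 specific attribute names

cur_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids[0].to(device)
# Dummy top-k candidates per position, shape (seq_len, k); in practice these come
# from my actual candidate generator.
cur_topk_idx = torch.randint(0, lm.config.vocab_size, (cur_ids.size(0), 5), device=device)
list_ids = set(range(lm.config.vocab_size))  # placeholder: allow every token id

logits, candidate_map = predict_each_candidate(cur_topk_idx, cur_ids, base_model, lm_head)
# logits: (n_candidates, 1, vocab_size); candidate_map[i] = (position, option index)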