Efficient batch inference using stacked past_key_values for multiple continuation candidates

Hi all,

I’m working on a task where, for each position in a sequence, I want to evaluate multiple possible token continuations. These continuations share the same prefix up to that point.

To speed things up, I compute past_key_values incrementally for each position, and then reuse that same cache for every candidate token at that position. I collect all candidate tokens into one batch and run a single forward pass, using each candidate's one-token input_ids together with a stacked version of its corresponding past_key_values.

Here’s a simplified description of the approach:

For each position t in a sequence:

  • Compute past_key_values up to that point.
  • For each candidate token:
    • Create a one-token input_ids with that candidate.
    • Store the candidate’s input_ids and the corresponding past_key_values.
  • Stack all candidate inputs and their past_key_values into a batch.
  • Run: model(input_ids=batch_input_ids, past_key_values=stacked_past, use_cache=False)
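
To make the cache layout concrete, here's a minimal sketch of what I'm assuming each layer's cache entry looks like (gpt2 is only an example model; on recent transformers versions past_key_values may be a Cache object rather than a plain tuple, but indexing a layer still gives a (key, value) pair):

import torch
from transformers import AutoModel, AutoTokenizer

# Example only: any causal LM's base model should expose the same cache layout.
tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModel.from_pretrained("gpt2")

ids = tok("The quick brown fox", return_tensors="pt").input_ids
with torch.inference_mode():
    out = model(input_ids=ids, use_cache=True)

past = out.past_key_values
# One (key, value) pair per layer; each tensor has shape
# [batch_size, num_heads, seq_len, head_dim].
print(len(past), past[0][0].shape)  # e.g. 12 torch.Size([1, 12, 4, 64]) for gpt2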

My questions are:

  • Is this approach aligned with how past_key_values are expected to be used in batch inference?
  • Are there potential pitfalls I should be aware of when batching multiple instances that have different past_key_values but sit at the same position in the context?

Any references to examples using similar logic would be appreciated!

Thanks in advance!

Here is my current implementation (simplified):

def predict_each_candidate(cur_topk_idx, cur_ids, model, lm_head): 
    # Efficiently evaluate multiple candidate tokens for each position in cur_ids,
    # using stacked past_key_values and batch inference.
    past = None
    candidate_batches = []
    candidate_pasts = []
    candidate_map = []

    for t, (gt_tok, pos_tok_list) in enumerate(zip(cur_ids, cur_topk_idx)):
        gt_tok_tensor = gt_tok.view(1, 1).to(device)
        with torch.inference_mode():
            out = model(input_ids=gt_tok_tensor, past_key_values=past, use_cache=True)
        past = out.past_key_values

        for option_i, tok in enumerate(pos_tok_list):
            if tok.item() in list_ids:
                input_ids = torch.tensor([[tok.item()]], device=device)
                candidate_batches.append(input_ids)
                candidate_pasts.append(past)
                candidate_map.append((t, option_i))

    if len(candidate_batches) > 0:
        batch_input_ids = torch.cat(candidate_batches, dim=0)

        def stack_past_key_values(past_list):
            num_layers = len(past_list[0])
            stacked = []
            for layer_idx in range(num_layers):
                keys = torch.stack([p[layer_idx][0] for p in past_list], dim=0)
                values = torch.stack([p[layer_idx][1] for p in past_list], dim=0)
                stacked.append((keys, values))
            return tuple(stacked)

        stacked_past = stack_past_key_values(candidate_pasts)

        with torch.inference_mode():
            out = model(input_ids=batch_input_ids, past_key_values=stacked_past, use_cache=False)
            logits = lm_head(out.last_hidden_state)

import torch


def predict_each_candidate(cur_topk_idx, cur_ids, model, lm_head, list_ids, device):
    """
    Efficiently evaluate multiple candidate tokens for each position in cur_ids,
    using stacked past_key_values and batch inference.
    """
    past = None
    candidate_batches = []
    candidate_pasts = []
    candidate_map = []

    for t, (gt_tok, pos_tok_list) in enumerate(zip(cur_ids, cur_topk_idx)):
        gt_tok_tensor = gt_tok.view(1, 1).to(device)
        with torch.inference_mode():
            out = model(input_ids=gt_tok_tensor, past_key_values=past, use_cache=True)
        past = out.past_key_values

        for option_i, tok in enumerate(pos_tok_list):
            if tok.item() in list_ids:
                input_ids = torch.tensor([[tok.item()]], device=device)
                candidate_batches.append(input_ids)
                candidate_pasts.append(past)
                candidate_map.append((t, option_i))

    if len(candidate_batches) > 0:
        batch_input_ids = torch.cat(candidate_batches, dim=0)

        def stack_past_key_values(past_list):
            num_layers = len(past_list[0])
            stacked = []
            for layer_idx in range(num_layers):
                # Concatenate along the batch dim so each layer's key/value ends up
                # with shape [num_candidates, num_heads, seq_len, head_dim].
                # Note: this assumes every cached prefix has the same seq_len;
                # caches of different lengths must be padded to a common length first.
                keys = torch.cat([p[layer_idx][0] for p in past_list], dim=0)
                values = torch.cat([p[layer_idx][1] for p in past_list], dim=0)
                stacked.append((keys, values))
            return tuple(stacked)

        stacked_past = stack_past_key_values(candidate_pasts)

        with torch.inference_mode():
            out = model(input_ids=batch_input_ids, past_key_values=stacked_past, use_cache=False)
            logits = lm_head(out.last_hidden_state)
        return logits, candidate_map  # Optionally return the (position, option) mapping
    else:
        return None, None

Solution provided by Triskel Data Deterministic AI.
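
For completeness, a quick usage sketch of the helper above: it assumes GPT-2's base transformer as model (so the output exposes last_hidden_state) with the LM head taken from GPT2LMHeadModel, and the prompt and candidate token ids below are made up. A single-position prefix is used so that every stacked cache has the same length; depending on the transformers version, the tuple-format cache may also need to be wrapped with DynamicCache.from_legacy_cache before being passed back to the model.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
lm = GPT2LMHeadModel.from_pretrained("gpt2").to(device).eval()
model, lm_head = lm.transformer, lm.lm_head  # base model + separate LM head

# Toy single-position example so all stacked caches share one seq_len.
cur_ids = tokenizer("Hello", return_tensors="pt").input_ids[0].to(device)  # shape [1]
cur_topk_idx = torch.tensor([[11, 13, 262, 290]], device=device)           # [1, 4] arbitrary candidate ids
list_ids = set(cur_topk_idx.flatten().tolist())                            # allow every candidate

logits, candidate_map = predict_each_candidate(
    cur_topk_idx, cur_ids, model, lm_head, list_ids, device
)
if logits is not None:
    print(logits.shape)   # [num_candidates, 1, vocab_size]
    print(candidate_map)  # [(0, 0), (0, 1), (0, 2), (0, 3)]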
