Efficient batch inference using stacked `past_key_values` for multiple continuation candidates: advance the KV cache over the ground-truth tokens one step at a time, queue every allowed candidate together with the cache of its prefix, then score all candidates in a single batched forward pass.

```python
import torch
import torch.nn.functional as F


def predict_each_candidate(cur_topk_idx, cur_ids, model, lm_head, list_ids, device):
    """
    Efficiently evaluate multiple candidate tokens for each position in cur_ids,
    using batched past_key_values and a single forward pass.

    Assumes the legacy tuple cache format: one (key, value) pair per layer,
    each of shape [batch, num_heads, seq_len, head_dim].
    """
    past = None
    candidate_batches = []
    candidate_pasts = []
    candidate_map = []

    for t, (gt_tok, pos_tok_list) in enumerate(zip(cur_ids, cur_topk_idx)):
        # Advance the cache over the ground-truth sequence one token at a time.
        gt_tok_tensor = gt_tok.view(1, 1).to(device)
        with torch.inference_mode():
            out = model(input_ids=gt_tok_tensor, past_key_values=past, use_cache=True)
        past = out.past_key_values

        # Queue every allowed candidate continuation at this position,
        # remembering which prefix cache it should attend to.
        for option_i, tok in enumerate(pos_tok_list):
            if tok.item() in list_ids:
                input_ids = torch.tensor([[tok.item()]], device=device)
                candidate_batches.append(input_ids)
                candidate_pasts.append(past)
                candidate_map.append((t, option_i))

    if len(candidate_batches) == 0:
        return None, None

    batch_input_ids = torch.cat(candidate_batches, dim=0)  # [num_candidates, 1]

    def stack_past_key_values(past_list):
        """Concatenate the per-candidate caches along the batch dimension.

        The caches come from prefixes of different lengths, so keys/values
        are left-padded with zeros to the longest past, and a matching
        attention mask zeros out the padded slots.
        """
        num_layers = len(past_list[0])
        max_len = max(p[0][0].shape[2] for p in past_list)
        stacked = []
        for layer_idx in range(num_layers):
            keys, values = [], []
            for p in past_list:
                k, v = p[layer_idx]
                pad = max_len - k.shape[2]
                if pad > 0:
                    # Pad the sequence dimension (dim 2) on the left.
                    k = F.pad(k, (0, 0, pad, 0))
                    v = F.pad(v, (0, 0, pad, 0))
                keys.append(k)
                values.append(v)
            stacked.append((torch.cat(keys, dim=0), torch.cat(values, dim=0)))
        # The mask covers the padded past plus the one new token per candidate.
        mask = torch.zeros(len(past_list), max_len + 1, dtype=torch.long, device=device)
        for i, p in enumerate(past_list):
            past_len = p[0][0].shape[2]
            mask[i, max_len - past_len:] = 1
        return tuple(stacked), mask

    stacked_past, attention_mask = stack_past_key_values(candidate_pasts)
    # True position of each candidate token: left-padding shifts the cache,
    # so models that take explicit position ids need the real prefix length.
    position_ids = torch.tensor(
        [[p[0][0].shape[2]] for p in candidate_pasts], device=device
    )

    with torch.inference_mode():
        out = model(
            input_ids=batch_input_ids,
            past_key_values=stacked_past,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=False,
        )
        logits = lm_head(out.last_hidden_state)
    # Row i of logits corresponds to candidate_map[i] == (position t, option rank).
    return logits, candidate_map
```
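
For reference, a minimal usage sketch, assuming a GPT-2 checkpoint where `model` is the base transformer (`lm.transformer`) and `lm_head` its output projection; the prompt, the top-5 candidate lists, and the allow-everything whitelist below are purely illustrative:

```python
from transformers import AutoTokenizer, GPT2LMHeadModel

device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
lm = GPT2LMHeadModel.from_pretrained("gpt2").to(device).eval()

ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids[0].to(device)
with torch.inference_mode():
    full_logits = lm(ids.unsqueeze(0)).logits[0]  # [seq_len, vocab]
topk_idx = full_logits.topk(5, dim=-1).indices    # 5 candidates per position
allowed = set(range(tokenizer.vocab_size))        # illustrative: allow every token

logits, candidate_map = predict_each_candidate(
    topk_idx, ids, lm.transformer, lm.lm_head, allowed, device
)
if logits is not None:
    next_token_logits = logits[:, -1, :]  # one row per queued candidate
```

Returning `candidate_map` alongside the logits lets the caller map each row of the batched output back to its (position, option) pair without re-deriving the queue order.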

Solution provided by Triskel Data Deterministic AI.
