import torch


def predict_each_candidate(cur_topk_idx, cur_ids, model, lm_head, list_ids, device):
    """
    Efficiently evaluate multiple candidate tokens for each position in
    cur_ids, using stacked past_key_values and a single batched forward pass.
    """
    past = None
    candidate_batches = []
    candidate_pasts = []
    candidate_map = []
    # Walk the ground-truth sequence one token at a time, growing a shared
    # KV cache, and collect the candidate tokens to score at each position.
    for t, (gt_tok, pos_tok_list) in enumerate(zip(cur_ids, cur_topk_idx)):
        gt_tok_tensor = gt_tok.view(1, 1).to(device)
        with torch.inference_mode():
            out = model(input_ids=gt_tok_tensor, past_key_values=past, use_cache=True)
        past = out.past_key_values
        # Newer transformers versions return a mutable Cache object that is
        # updated in place on later steps, so store an immutable tuple snapshot.
        snapshot = past.to_legacy_cache() if hasattr(past, "to_legacy_cache") else past
        for option_i, tok in enumerate(pos_tok_list):
            if tok.item() in list_ids:
                input_ids = torch.tensor([[tok.item()]], device=device)
                candidate_batches.append(input_ids)
                candidate_pasts.append(snapshot)
                candidate_map.append((t, option_i))
    if len(candidate_batches) > 0:
        batch_input_ids = torch.cat(candidate_batches, dim=0)
        def stack_past_key_values(past_list):
            # Merge the per-candidate legacy caches ([1, heads, seq, dim] per
            # layer) along the batch dimension. Candidates taken at different
            # positions have different cache lengths, so shorter caches are
            # left-padded with zeros and masked out via attention_mask below.
            num_layers = len(past_list[0])
            max_len = max(p[0][0].shape[2] for p in past_list)
            stacked = []
            for layer_idx in range(num_layers):
                keys, values = [], []
                for p in past_list:
                    k, v = p[layer_idx]
                    pad = max_len - k.shape[2]
                    if pad > 0:
                        # Pad the sequence dimension (dim 2) on the left.
                        k = torch.nn.functional.pad(k, (0, 0, pad, 0))
                        v = torch.nn.functional.pad(v, (0, 0, pad, 0))
                    keys.append(k)
                    values.append(v)
                # [num_candidates, heads, max_len, dim] for each layer
                stacked.append((torch.cat(keys, dim=0), torch.cat(values, dim=0)))
            return tuple(stacked)
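        # Worked shape example (hypothetical numbers): with 3 candidates whose
        # caches span 2, 5, and 5 past tokens, each layer ends up with keys
        # and values of shape [3, heads, 5, head_dim]; the first row carries 3
        # left-padded slots that the attention_mask below zeroes out.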
        stacked_past = stack_past_key_values(candidate_pasts)
        # On very recent transformers releases, the legacy tuple may need
        # wrapping via transformers.DynamicCache.from_legacy_cache.
        # Mask the left-padding and give each row its true next-token
        # position, since earlier-position candidates have shorter caches.
        past_lens = torch.tensor([p[0][0].shape[2] for p in candidate_pasts], device=device)
        max_len = int(past_lens.max())
        attention_mask = torch.zeros(len(candidate_pasts), max_len + 1, dtype=torch.long, device=device)
        for i, n in enumerate(past_lens.tolist()):
            attention_mask[i, max_len - n:] = 1
        with torch.inference_mode():
            out = model(
                input_ids=batch_input_ids,
                past_key_values=stacked_past,
                attention_mask=attention_mask,
                position_ids=past_lens.unsqueeze(-1),
                use_cache=False,
            )
        logits = lm_head(out.last_hidden_state)
        return logits, candidate_map  # rows of logits map back via candidate_map
    else:
        return None, None
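

# Usage sketch below: a minimal, hypothetical driver assuming a Hugging Face
# GPT-2 checkpoint, where `model` is the bare transformer (which returns
# last_hidden_state) and `lm_head` is its output projection. The prompt,
# the random top-k proposals, and `list_ids` are illustrative placeholders.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    lm = AutoModelForCausalLM.from_pretrained("gpt2").to(device).eval()

    cur_ids = tokenizer("The quick brown fox", return_tensors="pt").input_ids[0].to(device)
    # Stand-in top-k proposals per position (e.g. from a draft model).
    cur_topk_idx = torch.randint(0, lm.config.vocab_size, (len(cur_ids), 4), device=device)
    list_ids = set(cur_topk_idx.flatten().tolist())  # accept every proposal

    logits, candidate_map = predict_each_candidate(
        cur_topk_idx, cur_ids, lm.transformer, lm.lm_head, list_ids, device
    )
    if logits is not None:
        print(logits.shape)       # [num_candidates, 1, vocab_size]
        print(candidate_map[:5])  # rows of logits -> (position, option) pairs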