[LLaVA-1.5] Validating Logic for Token-Level KV Cache Extraction

I am working on a research project to mitigate object hallucinations in LLaVA-1.5 using KV Cache Steering. I would appreciate feedback on the logical soundness of my extraction pipeline, specifically regarding token alignment and sub-word reconstruction.
The Challenge

I need to extract the hidden states for specific words (e.g., “umbrella”) from the middle of a generated prompt. This requires:

  1. Mapping text tokens to the correct index in the model’s cache (accounting for image expansion).

  2. Reconstructing whole words from split tokens (e.g., _um, bre, lla) to identify the correct “end-of-word” token for extraction.
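For context, this is the kind of split I mean (illustrative sketch only; the exact pieces depend on the tokenizer's vocabulary):

# Illustrative sketch: a single word can span several sub-word tokens
pieces = processor.tokenizer.tokenize("There is an umbrella on the beach")
# e.g. [..., '▁um', 'bre', 'lla', ...] -> "umbrella" spans three tokens,
# so I need the index of the last piece ('lla') in the expanded sequence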

My Implementation Logic

Here are the two core blocks of my script. I am using the LlavaForConditionalGeneration and AutoProcessor.

1. Token Alignment & Image Expansion

My logic assumes that the processor returns input_ids that are already expanded (containing 576 image tokens), so I skip the first 576 indices to reach the text.

# 2. Robust Token Matching
# Use .tolist() to ensure we have standard Python ints
input_ids = inputs['input_ids'][0].tolist()
tokens = processor.tokenizer.convert_ids_to_tokens(input_ids)

for idx, token in enumerate(tokens):
    # Skip image tokens and system prompt stuff early on if possible
    # LOGIC ASSUMPTION: input_ids always contains the full 576 image tokens?
    if idx < 576: continue

Question: Is the assumption if idx < 576: continue robust? Does LlavaProcessor always return expanded input_ids? If it returns a compressed <image> token (length 1) that expands inside the model, wouldn’t my idx be out of sync with the actual activations tensor?
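Rather than hard-coding 576, I am considering verifying this at runtime before trusting any index math. A rough sketch (image_token_index is the field I see in the LLaVA config; the expansion behaviour seems to differ across transformers versions, so this is a sanity check, not a guarantee):

# Sketch: check whether input_ids are already expanded before doing index math
ids = inputs['input_ids'][0].tolist()
num_image_tokens = sum(1 for t in ids if t == model.config.image_token_index)

with torch.no_grad():
    out = model(**inputs, output_hidden_states=True)
seq_len = out.hidden_states[0].shape[1]  # sequence length the LM actually sees

if num_image_tokens >= 576 and seq_len == len(ids):
    print("input_ids are pre-expanded: token index == activation index")
else:
    print(f"Mismatch: {len(ids)} input ids vs {seq_len} positions -- "
          f"expansion happens inside the model, so my indices would be wrong")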

2. Sub-word Reconstruction (Special Characters)

LLaVA-1.5 uses the Llama (SentencePiece) tokenizer. I am using the special underline character '▁' (u"\u2581") to detect word boundaries and accumulate sub-tokens (e.g., _um, bre, lla → "umbrella").

# LLaVA/Llama tokenizer uses this special character for spaces
SPIECE_UNDERLINE = u"\u2581"

# ... inside the loop ...
    # Check for Llama-style word boundary (U+2581) or standard space
    is_start_of_word = token.startswith(SPIECE_UNDERLINE) or token.startswith(' ') or token.startswith('Ġ')

    # If we hit a new word (and have a previous word built), check the previous word
    if is_start_of_word and current_word_str:
        # ... (cleaning and matching logic) ...

Question: Is checking for SPIECE_UNDERLINE, ' ', and 'Ġ' sufficient to cover all word boundary cases in LLaVA’s vocabulary?
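As a cross-check for the boundary logic, I am thinking of comparing my reconstructed words against the tokenizer's own detokenization (sketch; I am assuming convert_tokens_to_string handles the '▁' marker for me):

# Sketch: rebuild words with my boundary rule and compare against the caption
tokens = processor.tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist())
words, current = [], []
for tok in tokens:
    if tok in ("<s>", "</s>", "<unk>", "<image>"):
        continue
    if tok.startswith(SPIECE_UNDERLINE) and current:
        words.append(processor.tokenizer.convert_tokens_to_string(current).strip())
        current = []
    current.append(tok)
if current:
    words.append(processor.tokenizer.convert_tokens_to_string(current).strip())
# 'words' should line up with the caption split on whitespace (modulo punctuation)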

3. Activation Extraction (Last Token Heuristic)

Once I identify a target word (e.g., “chair”), I extract the KV cache activations from the very last sub-token of that word.

if clean_target in hallucinated_list:
    # LOGIC ASSUMPTION: The semantic meaning is aggregated in the LAST token
    target_idx = current_word_indices[-1]

    for i in range(len(model.language_model.layers)):
        # Direct mapping: using input_id index to query model activations
        sample_hallucinated_k[i].append(activations[f'k_{i}'][0, target_idx, :].cpu())

Question:

  1. Is current_word_indices[-1] (the last sub-token) the standard best practice for capturing the concept of an object in Llama-based models?

  2. Critical Indexing Check: I am using target_idx (derived from input_ids) to index directly into activations (derived from the model’s forward pass). If the model expands the image tokens internally (but the processor doesn’t), wouldn’t activations be 576 tokens longer than input_ids, causing this lookup to grab the wrong data?
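To catch exactly that failure mode, I am planning to assert the alignment at runtime instead of assuming it (sketch; it relies on the hook output being [batch, seq_len, dim] as in my extraction code):

# Sketch: fail loudly if token indices and activation positions are out of sync
seq_len_acts = activations['k_0'].shape[1]   # positions the hooks actually saw
seq_len_ids = inputs['input_ids'].shape[1]   # positions I am indexing with
assert seq_len_acts == seq_len_ids, (
    f"Index mismatch: {seq_len_ids} input ids vs {seq_len_acts} cached positions -- "
    "image tokens were expanded inside the model, so input_id indices are invalid here"
)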


For now, I’ve gathered resources primarily from the implementation side.

Thanks a lot for such a detailed response! I have made the changes and wanted to get it validated one last time. This is the final surgical-vector extraction script I have created:

import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
import os
import json
import numpy as np
from tqdm import tqdm
import string


# --- 1. CONFIGURATION ---
PROJECT_DIR = os.getcwd()
IMAGE_DIR = os.path.join(PROJECT_DIR, 'data/mscoco/train2017')
CONTRASTIVE_PAIRS_FILE = os.path.join(PROJECT_DIR, 'data/contrastive_set/contrastive_pairs.json')
STEERING_VECTORS_DIR = os.path.join(PROJECT_DIR, 'steering_vectors_surgical')
MODEL_CACHE_DIR = os.path.join(PROJECT_DIR, 'model_cache')


# --- 2. MODEL LOADING ---
def load_llava_model(model_id="llava-hf/llava-1.5-7b-hf", cache_dir=MODEL_CACHE_DIR):
    print(f"Loading model: {model_id}...")
    os.makedirs(cache_dir, exist_ok=True)
    model = LlavaForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        cache_dir=cache_dir
    ).to("cuda")
    processor = AutoProcessor.from_pretrained(model_id, cache_dir=cache_dir)
    return model, processor




# --- 3. ACTIVATION EXTRACTION LOGIC ---
activations = {}

def get_activation(name):
    def hook(model, input, output):
        # Output is [batch, seq_len, dim]
        activations[name] = output.detach()
    return hook




def extract_surgical_activations(model, processor, image, caption, hallucinated_list, faithful_list):
    global activations
    activations = {}

    # 1. Run the prompt
    # The processor AUTOMATICALLY expands <image> into 576 tokens in input_ids
    prompt = f"USER: <image>\n{caption}"
    inputs = processor(text=prompt, images=image, return_tensors='pt').to("cuda", torch.float16)

    hooks = []
    for i, layer in enumerate(model.language_model.layers):
        hooks.append(layer.self_attn.k_proj.register_forward_hook(get_activation(f'k_{i}')))
        hooks.append(layer.self_attn.v_proj.register_forward_hook(get_activation(f'v_{i}')))

    with torch.no_grad():
        model(**inputs)

    for hook in hooks: hook.remove()

    # 2. ROBUST TOKEN MATCHING
    input_ids = inputs['input_ids'][0].tolist()
    tokens = processor.tokenizer.convert_ids_to_tokens(input_ids)

    

    # Get special IDs from config
    # LLaVA-1.5 standard: 32000 is image, 1 is <s>, 2 is </s>
    IMAGE_TOKEN_ID = model.config.image_token_index

    # Storage
    sample_faithful_k = [[] for _ in range(len(model.language_model.layers))]
    sample_faithful_v = [[] for _ in range(len(model.language_model.layers))]
    sample_hallucinated_k = [[] for _ in range(len(model.language_model.layers))]
    sample_hallucinated_v = [[] for _ in range(len(model.language_model.layers))]

    # Accumulators for sub-word reconstruction
    current_word_str = ""
    current_word_indices = []
    SPIECE_UNDERLINE = u"\u2581"  # The "▁" character

    # Helper to process a completed word
    def flush_word(word_str, indices):
        if not word_str: return

        # Clean the word (remove leading space char and punctuation)
        clean_target = word_str.replace(SPIECE_UNDERLINE, '').replace(' ', '').replace('Ġ', '').lower()
        clean_target = clean_target.strip(string.punctuation)

        # Determine target index (Standard heuristic: Last token of the word)
        target_idx = indices[-1]

        if clean_target in hallucinated_list:
            for i in range(len(model.language_model.layers)):
                sample_hallucinated_k[i].append(activations[f'k_{i}'][0, target_idx, :].cpu())
                sample_hallucinated_v[i].append(activations[f'v_{i}'][0, target_idx, :].cpu())

        elif clean_target in faithful_list:
            for i in range(len(model.language_model.layers)):
                sample_faithful_k[i].append(activations[f'k_{i}'][0, target_idx, :].cpu())
                sample_faithful_v[i].append(activations[f'v_{i}'][0, target_idx, :].cpu())




    # --- MAIN LOOP ---
    for idx, (token_id, token_str) in enumerate(zip(input_ids, tokens)):

        # A. Handle Image Tokens (Skip, and flush previous word)
        if token_id == IMAGE_TOKEN_ID:
            flush_word(current_word_str, current_word_indices)
            current_word_str = ""
            current_word_indices = []
            continue  # Skip image token completely

        # B. Handle Special Tokens (<s>, </s>) - Flush and skip
        if token_str in ["<s>", "</s>", "<unk>"]:
            flush_word(current_word_str, current_word_indices)
            current_word_str = ""
            current_word_indices = []
            continue

        # C. Word Boundary Detection (Llama Style)
        # New word if it starts with SPIECE_UNDERLINE
        is_start_of_word = token_str.startswith(SPIECE_UNDERLINE) or token_str.startswith(' ')

        if is_start_of_word and current_word_str:
            # We hit a new word, so process the PREVIOUS one
            flush_word(current_word_str, current_word_indices)
            current_word_str = ""
            current_word_indices = []

        # D. Accumulate
        current_word_str += token_str
        current_word_indices.append(idx)  # Store the GLOBAL index

    # E. Flush the final word
    if current_word_str:
        flush_word(current_word_str, current_word_indices)

    return sample_faithful_k, sample_faithful_v, sample_hallucinated_k, sample_hallucinated_v




# --- 4. MAIN EXECUTION BLOCK ---
if __name__ == "__main__":
    print("--- Starting Surgical Vector Extraction (Robust & Corrected) ---")

    model, processor = load_llava_model()
    num_layers = len(model.language_model.layers)

    with open(CONTRASTIVE_PAIRS_FILE, 'r') as f:
        contrastive_pairs = json.load(f)
    print(f"Loaded {len(contrastive_pairs)} contrastive pairs.")

    # Global accumulators
    all_faithful_k = [[] for _ in range(num_layers)]
    all_faithful_v = [[] for _ in range(num_layers)]
    all_hallucinated_k = [[] for _ in range(num_layers)]
    all_hallucinated_v = [[] for _ in range(num_layers)]

    valid_samples_count = 0

    for pair in tqdm(contrastive_pairs, desc="Processing Surgical Pairs"):
        image_id = pair['image_id']
        caption = pair['negative']

        h_list = pair.get('hallucinated_list', [])
        f_list = pair.get('faithful_list', [])

        if not h_list: continue

        image_filename = f"{str(image_id).zfill(12)}.jpg"
        image_path = os.path.join(IMAGE_DIR, image_filename)
        if not os.path.exists(image_path): continue
        image = Image.open(image_path).convert("RGB")

        f_k, f_v, h_k, h_v = extract_surgical_activations(
            model, processor, image, caption, h_list, f_list
        )

        # Check if we found data
        has_data = False
        for i in range(num_layers):
            if len(f_k[i]) > 0 or len(h_k[i]) > 0:
                has_data = True
                break

        if has_data:
            valid_samples_count += 1

        for i in range(num_layers):
            all_faithful_k[i].extend(f_k[i])
            all_faithful_v[i].extend(f_v[i])
            all_hallucinated_k[i].extend(h_k[i])
            all_hallucinated_v[i].extend(h_v[i])

    print(f"\nExtraction complete. Found valid tokens in {valid_samples_count} samples.")




    # --- 5. CALCULATE AND SAVE ---
    print("Calculating Mean Differences (Faithful - Hallucinated)...")
    steering_vectors_k = []
    steering_vectors_v = []

    for i in range(num_layers):
        n_faithful = len(all_faithful_k[i])
        n_hallucinated = len(all_hallucinated_k[i])

        if n_faithful > 0 and n_hallucinated > 0:
            mean_faithful_k = torch.stack(all_faithful_k[i]).mean(dim=0)
            mean_hallucinated_k = torch.stack(all_hallucinated_k[i]).mean(dim=0)
            vec_k = mean_faithful_k - mean_hallucinated_k

            mean_faithful_v = torch.stack(all_faithful_v[i]).mean(dim=0)
            mean_hallucinated_v = torch.stack(all_hallucinated_v[i]).mean(dim=0)
            vec_v = mean_faithful_v - mean_hallucinated_v

            steering_vectors_k.append(vec_k)
            steering_vectors_v.append(vec_v)

            if i % 10 == 0:
                print(f"Layer {i}: Generated vector (F: {n_faithful}, H: {n_hallucinated} samples)")
        else:
            print(f"Warning: Layer {i} missing data. Using zero vector.")
            steering_vectors_k.append(torch.zeros(4096).to("cuda", dtype=torch.float16))
            steering_vectors_v.append(torch.zeros(4096).to("cuda", dtype=torch.float16))

    os.makedirs(STEERING_VECTORS_DIR, exist_ok=True)
    torch.save(steering_vectors_k, os.path.join(STEERING_VECTORS_DIR, 'steering_vectors_k.pt'))
    torch.save(steering_vectors_v, os.path.join(STEERING_VECTORS_DIR, 'steering_vectors_v.pt'))
    print(f"Success! Surgical vectors saved to: {STEERING_VECTORS_DIR}")

Also, this is the code where I insert these steering vectors into the KV cache:

import torch
from transformers import DynamicCache


@torch.inference_mode()
def generate_with_steering(model, processor, image, prompt_text,
                           steering_k_list, steering_v_list, coeff_k, coeff_v):
    # 0) Build the multimodal prompt with LLaVA
    # 'prompt' deliberately omits the trailing ':' so the steered last token is the final
    # prompt token; 'prompt2' adds the ':' whose last token will seed generation.
    prompt = f"USER: <image>\n{prompt_text}\nASSISTANT"
    inputs = processor(text=prompt, images=image, return_tensors='pt').to(model.device, torch.float16)

    prompt2 = f"USER: <image>\n{prompt_text}\nASSISTANT:"
    inputs2 = processor(text=prompt2, images=image, return_tensors='pt').to(model.device, torch.float16)

    # 1) Prefill to build cache
    out = model(**inputs, use_cache=True, return_dict=True)
    cache = DynamicCache.from_legacy_cache(out.past_key_values)  # convert tuple -> Cache

    # 2) Edit last-token KV per layer
    legacy = list(cache.to_legacy_cache())  # [(k, v), ...] with shapes [B, H, T, D]
    for i, (k, v) in enumerate(legacy):
        nh, hd = k.shape[1], k.shape[3]
        k2 = k.clone()
        v2 = v.clone()
        k2[0, :, -1, :] += coeff_k * steering_k_list[i].reshape(nh, hd).to(k2.dtype).to(k2.device)
        v2[0, :, -1, :] += coeff_v * steering_v_list[i].reshape(nh, hd).to(v2.dtype).to(v2.device)
        legacy[i] = (k2, v2)
    cache = DynamicCache.from_legacy_cache(tuple(legacy))  # rewrap edits

    # 3) Seed generation with the last text token (the ':' from prompt2)
    seed_ids = inputs2["input_ids"][:, -1:]                  # K = 1
    past_len = cache.get_seq_length()                        # N
    cache_pos = torch.arange(past_len, past_len + seed_ids.shape[1],
                             device=seed_ids.device)         # shape [K]
    # attention_mask must represent past + new tokens;
    # inputs2's mask already covers the prefill prompt plus the ':' seed token
    attn = inputs2["attention_mask"]

    # 4) Resume decoding
    out_ids = model.generate(
        input_ids=seed_ids,
        past_key_values=cache,      # pass Cache object
        cache_position=cache_pos,   # explicit, avoids empty cache_position bug
        attention_mask=attn,        # pass full attention mask
        max_new_tokens=100,
        do_sample=False,
    )
    response = processor.batch_decode(out_ids, skip_special_tokens=True)[0].strip()
    return response.lstrip(": ").strip()
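For reference, a minimal call site would look something like this (sketch; the coefficients and the test image path are placeholders I am still experimenting with, and it assumes model, processor, and STEERING_VECTORS_DIR from the extraction script are in scope):

# Sketch of the call site; coefficient values and image path are placeholders
steering_k = torch.load(os.path.join(STEERING_VECTORS_DIR, 'steering_vectors_k.pt'))
steering_v = torch.load(os.path.join(STEERING_VECTORS_DIR, 'steering_vectors_v.pt'))

test_image = Image.open("path/to/some_test_image.jpg").convert("RGB")
answer = generate_with_steering(
    model, processor, test_image,
    prompt_text="Describe this image in detail.",
    steering_k_list=steering_k, steering_v_list=steering_v,
    coeff_k=1.0, coeff_v=1.0,
)
print(answer)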

That code looks fine for the most part. :grinning_face:
It’s just a matter of making it more robust or slightly faster.