Thanks a lot for such a detailed response! I have made the changes and wanted to get everything validated one last time. This is the final surgical-vector extraction file I have put together:
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
import os
import json
import numpy as np
from tqdm import tqdm
import string
# --- 1. CONFIGURATION ---
PROJECT_DIR = os.getcwd()
IMAGE_DIR = os.path.join(PROJECT_DIR, 'data/mscoco/train2017')
CONTRASTIVE_PAIRS_FILE = os.path.join(PROJECT_DIR, 'data/contrastive_set/contrastive_pairs.json')
STEERING_VECTORS_DIR = os.path.join(PROJECT_DIR, 'steering_vectors_surgical')
MODEL_CACHE_DIR = os.path.join(PROJECT_DIR, 'model_cache')
# --- 2. MODEL LOADING ---
def load_llava_model(model_id="llava-hf/llava-1.5-7b-hf", cache_dir=MODEL_CACHE_DIR):
    print(f"Loading model: {model_id}...")
    os.makedirs(cache_dir, exist_ok=True)
    model = LlavaForConditionalGeneration.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        cache_dir=cache_dir
    ).to("cuda")
    processor = AutoProcessor.from_pretrained(model_id, cache_dir=cache_dir)
    return model, processor
# --- 3. ACTIVATION EXTRACTION LOGIC ---
activations = {}
def get_activation(name):
    def hook(model, input, output):
        # k_proj/v_proj output is [batch, seq_len, num_heads * head_dim]
        activations[name] = output.detach()
    return hook
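# Note: the hooks fire once per forward pass and overwrite `activations`, so
# the dict only ever holds the k/v projections for the current prompt.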
def extract_surgical_activations(model, processor, image, caption, hallucinated_list, faithful_list):
    global activations
    activations = {}
    # 1. Run the prompt
    # The processor AUTOMATICALLY expands <image> into 576 tokens in input_ids
    prompt = f"USER: <image>\n{caption}"
    inputs = processor(text=prompt, images=image, return_tensors='pt').to("cuda", torch.float16)
    hooks = []
    for i, layer in enumerate(model.language_model.layers):
        hooks.append(layer.self_attn.k_proj.register_forward_hook(get_activation(f'k_{i}')))
        hooks.append(layer.self_attn.v_proj.register_forward_hook(get_activation(f'v_{i}')))
    with torch.no_grad():
        model(**inputs)
    for hook in hooks:
        hook.remove()
    # 2. ROBUST TOKEN MATCHING
    input_ids = inputs['input_ids'][0].tolist()
    tokens = processor.tokenizer.convert_ids_to_tokens(input_ids)
    # Get special IDs from config
    # LLaVA-1.5 standard: 32000 is image, 1 is <s>, 2 is </s>
    IMAGE_TOKEN_ID = model.config.image_token_index
    # Storage
    sample_faithful_k = [[] for _ in range(len(model.language_model.layers))]
    sample_faithful_v = [[] for _ in range(len(model.language_model.layers))]
    sample_hallucinated_k = [[] for _ in range(len(model.language_model.layers))]
    sample_hallucinated_v = [[] for _ in range(len(model.language_model.layers))]
    # Accumulators for sub-word reconstruction
    current_word_str = ""
    current_word_indices = []
    SPIECE_UNDERLINE = "\u2581"  # The "▁" character
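    # Illustrative example (assuming the Llama SentencePiece tokenizer):
    # "a surfboard" might tokenize as ["▁a", "▁surf", "board"]; the loop below
    # stitches "▁surf" + "board" back into "surfboard" and reads the activation
    # at the index of the LAST sub-token ("board").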
    # Helper to process a completed word
    def flush_word(word_str, indices):
        if not word_str:
            return
        # Clean the word (remove leading space char and punctuation)
        clean_target = word_str.replace(SPIECE_UNDERLINE, '').replace(' ', '').replace('Ġ', '').lower()
        clean_target = clean_target.strip(string.punctuation)
        # Determine target index (Standard heuristic: Last token of the word)
        target_idx = indices[-1]
        if clean_target in hallucinated_list:
            for i in range(len(model.language_model.layers)):
                sample_hallucinated_k[i].append(activations[f'k_{i}'][0, target_idx, :].cpu())
                sample_hallucinated_v[i].append(activations[f'v_{i}'][0, target_idx, :].cpu())
        elif clean_target in faithful_list:
            for i in range(len(model.language_model.layers)):
                sample_faithful_k[i].append(activations[f'k_{i}'][0, target_idx, :].cpu())
                sample_faithful_v[i].append(activations[f'v_{i}'][0, target_idx, :].cpu())
    # --- MAIN LOOP ---
    for idx, (token_id, token_str) in enumerate(zip(input_ids, tokens)):
        # A. Handle Image Tokens (Skip, and flush previous word)
        if token_id == IMAGE_TOKEN_ID:
            flush_word(current_word_str, current_word_indices)
            current_word_str = ""
            current_word_indices = []
            continue  # Skip image token completely
        # B. Handle Special Tokens (<s>, </s>) - Flush and skip
        if token_str in ["<s>", "</s>", "<unk>"]:
            flush_word(current_word_str, current_word_indices)
            current_word_str = ""
            current_word_indices = []
            continue
        # C. Word Boundary Detection (Llama Style)
        # New word if it starts with SPIECE_UNDERLINE
        is_start_of_word = token_str.startswith(SPIECE_UNDERLINE) or token_str.startswith(' ')
        if is_start_of_word and current_word_str:
            # We hit a new word, so process the PREVIOUS one
            flush_word(current_word_str, current_word_indices)
            current_word_str = ""
            current_word_indices = []
        # D. Accumulate
        current_word_str += token_str
        current_word_indices.append(idx)  # Store the GLOBAL index
    # E. Flush the final word
    if current_word_str:
        flush_word(current_word_str, current_word_indices)
    return sample_faithful_k, sample_faithful_v, sample_hallucinated_k, sample_hallucinated_v
# --- 4. MAIN EXECUTION BLOCK ---
if __name__ == "__main__":
    print("--- Starting Surgical Vector Extraction (Robust & Corrected) ---")
    model, processor = load_llava_model()
    num_layers = len(model.language_model.layers)
    with open(CONTRASTIVE_PAIRS_FILE, 'r') as f:
        contrastive_pairs = json.load(f)
    print(f"Loaded {len(contrastive_pairs)} contrastive pairs.")
    # Global accumulators
    all_faithful_k = [[] for _ in range(num_layers)]
    all_faithful_v = [[] for _ in range(num_layers)]
    all_hallucinated_k = [[] for _ in range(num_layers)]
    all_hallucinated_v = [[] for _ in range(num_layers)]
    valid_samples_count = 0
    for pair in tqdm(contrastive_pairs, desc="Processing Surgical Pairs"):
        image_id = pair['image_id']
        caption = pair['negative']
        h_list = pair.get('hallucinated_list', [])
        f_list = pair.get('faithful_list', [])
        if not h_list:
            continue
        image_filename = f"{str(image_id).zfill(12)}.jpg"
        image_path = os.path.join(IMAGE_DIR, image_filename)
        if not os.path.exists(image_path):
            continue
        image = Image.open(image_path).convert("RGB")
        f_k, f_v, h_k, h_v = extract_surgical_activations(
            model, processor, image, caption, h_list, f_list
        )
        # Check if we found data
        has_data = False
        for i in range(num_layers):
            if len(f_k[i]) > 0 or len(h_k[i]) > 0:
                has_data = True
                break
        if has_data:
            valid_samples_count += 1
            for i in range(num_layers):
                all_faithful_k[i].extend(f_k[i])
                all_faithful_v[i].extend(f_v[i])
                all_hallucinated_k[i].extend(h_k[i])
                all_hallucinated_v[i].extend(h_v[i])
print(f"\nExtraction complete. Found valid tokens in {valid_samples_count} samples.")
# --- 5. CALCULATE AND SAVE ---
print("Calculating Mean Differences (Faithful - Hallucinated)...")
steering_vectors_k = []
steering_vectors_v = []
for i in range(num_layers):
n_faithful = len(all_faithful_k[i])
n_hallucinated = len(all_hallucinated_k[i])
if n_faithful > 0 and n_hallucinated > 0:
mean_faithful_k = torch.stack(all_faithful_k[i]).mean(dim=0)
mean_hallucinated_k = torch.stack(all_hallucinated_k[i]).mean(dim=0)
vec_k = mean_faithful_k - mean_hallucinated_k
mean_faithful_v = torch.stack(all_faithful_v[i]).mean(dim=0)
mean_hallucinated_v = torch.stack(all_hallucinated_v[i]).mean(dim=0)
vec_v = mean_faithful_v - mean_hallucinated_v
steering_vectors_k.append(vec_k)
steering_vectors_v.append(vec_v)
if i % 10 == 0:
print(f"Layer {i}: Generated vector (F: {n_faithful}, H: {n_hallucinated} samples)")
else:
print(f"Warning: Layer {i} missing data. Using zero vector.")
steering_vectors_k.append(torch.zeros(4096).to("cuda", dtype=torch.float16))
steering_vectors_v.append(torch.zeros(4096).to("cuda", dtype=torch.float16))
os.makedirs(STEERING_VECTORS_DIR, exist_ok=True)
torch.save(steering_vectors_k, os.path.join(STEERING_VECTORS_DIR, 'steering_vectors_k.pt'))
torch.save(steering_vectors_v, os.path.join(STEERING_VECTORS_DIR, 'steering_vectors_v.pt'))
print(f"Success! Surgical vectors saved to: {STEERING_VECTORS_DIR}")
Also, this is the code where I insert the steering vectors into the KV cache:
from transformers.cache_utils import DynamicCache

@torch.inference_mode()
def generate_with_steering(model, processor, image, prompt_text,
                           steering_k_list, steering_v_list, coeff_k, coeff_v):
    # 0) Build the multimodal prompt with LLaVA.
    # `prompt` (without the trailing ":") is what gets prefilled, so the KV edit
    # below lands on the last prompt token; `prompt2` adds the ":" used to seed decoding.
    prompt = f"USER: <image>\n{prompt_text}\nASSISTANT"
    inputs = processor(text=prompt, images=image, return_tensors='pt').to(model.device, torch.float16)
    prompt2 = f"USER: <image>\n{prompt_text}\nASSISTANT:"
    inputs2 = processor(text=prompt2, images=image, return_tensors='pt').to(model.device, torch.float16)
    # 1) Prefill to build the KV cache
    out = model(**inputs, use_cache=True, return_dict=True)
    cache = DynamicCache.from_legacy_cache(out.past_key_values)  # convert tuple -> Cache
    # 2) Edit last-token KV per layer
    legacy = list(cache.to_legacy_cache())  # [(k, v), ...] with shapes [B, H, T, D]
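    # Shape note (LLaVA-1.5-7B / Llama-2-7B, which uses full MHA, no GQA):
    # each k and v is [1, 32, T, 128], so the 4096-dim steering vector
    # reshapes to (32, 128) and every head receives its own 128-dim slice.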
    for i, (k, v) in enumerate(legacy):
        nh, hd = k.shape[1], k.shape[3]
        k2 = k.clone()
        v2 = v.clone()
        k2[0, :, -1, :] += coeff_k * steering_k_list[i].reshape(nh, hd).to(k2.dtype).to(k2.device)
        v2[0, :, -1, :] += coeff_v * steering_v_list[i].reshape(nh, hd).to(v2.dtype).to(v2.device)
        legacy[i] = (k2, v2)
    cache = DynamicCache.from_legacy_cache(tuple(legacy))  # rewrap edits
    # 3) Seed generation with the last text token (the ":")
    seed_ids = inputs2["input_ids"][:, -1:]  # K = 1
    past_len = cache.get_seq_length()  # N
    cache_pos = torch.arange(past_len, past_len + seed_ids.shape[1],
                             device=seed_ids.device)  # positions [N, N+K)
    # attention_mask must represent past + new tokens: the prefill mask plus
    # ones for the K seed tokens, for a total length of N + K
    attn = torch.cat(
        [inputs["attention_mask"],
         inputs["attention_mask"].new_ones((inputs["attention_mask"].size(0), seed_ids.size(1)))],
        dim=-1
    )
    # 4) Resume decoding
    out_ids = model.generate(
        input_ids=seed_ids,
        past_key_values=cache,  # pass Cache object
        cache_position=cache_pos,  # explicit, avoids empty cache_position bug
        attention_mask=attn,  # pass full attention mask
        max_new_tokens=100,
        do_sample=False,
    )
    response = processor.batch_decode(out_ids, skip_special_tokens=True)[0].strip()
    return response.lstrip(": ").strip()
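And here is how I call it end-to-end (illustrative only: the image path and the coeff_k / coeff_v strengths are placeholders I'm still tuning):

steering_k = torch.load('steering_vectors_surgical/steering_vectors_k.pt', map_location='cpu')
steering_v = torch.load('steering_vectors_surgical/steering_vectors_v.pt', map_location='cpu')
model, processor = load_llava_model()
image = Image.open('data/mscoco/train2017/000000000139.jpg').convert('RGB')  # placeholder image
response = generate_with_steering(
    model, processor, image,
    "Describe this image in detail.",
    steering_k, steering_v,
    coeff_k=2.0, coeff_v=2.0,  # placeholder coefficients
)
print(response)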