How to get normal LLaVA-1.6 attention maps?

Hi~

I want to use attention maps to visualize the relationships between tokens like this map:

I set the keyword arguments output_attentions=True and return_dict_in_generate=True in generate(), hoping to get the corresponding attention maps.

The code is as follows:

from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
import torch
from PIL import Image
import requests

processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")

model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True)
model.to("cuda:0")
print("done")

path = "/mnt/workspace/workgroup/lz/111.jpg"
image = Image.open(path)

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    },
]
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
inputs = processor(images=image, text=prompt, return_tensors="pt").to("cuda:0")

# autoregressively complete prompt
output = model.generate(**inputs, max_new_tokens=256, output_attentions=True, return_dict_in_generate=True)

# Inspect the per-layer attentions returned for the last generation step
for attn in output.attentions[-1]:
    print(attn.size())  # torch.Size([1, 32, 1, 2353])

But the tensors do not follow the shape “(batch_size, num_heads, sequence_length, sequence_length)” documented in transformers/src/transformers/models/llava_next/modeling_llava_next.py (line 177).

Ideally, the output attention map should be something like [1, 32, 768, 768] in size instead of [1, 32, 1, 2353].

How can I get the attention map in a normal 2D size, and why do I get outputs of size [1, 32, 1, 2353]?


I asked Hugging Chat. I’m not sure if this is the right answer, but…


To get the attention map in a normal 2D size when working with the LlavaNext model, you need to reshape the output attention tensor. The shape [1, 32, 1, 2353] indicates that the model outputs attention weights for a single query token across all 2353 key positions in the sequence. To visualize this as a 2D map, you can extract the attention weights for a single attention head and reshape that vector into a 2D matrix.

Here’s how you can modify your code to achieve this:

import matplotlib.pyplot as plt
import numpy as np

# After generating the output:
output = model.generate(**inputs, max_new_tokens=256, output_attentions=True, return_dict_in_generate=True)

# output.attentions is a tuple with one entry per generated token;
# each entry is itself a tuple of per-layer tensors of shape [1, 32, 1, seq_len]
last_step = output.attentions[-1]   # attentions for the last generated token
last_layer = last_step[-1]          # last decoder layer

# Take the first attention head (you can adjust this if needed)
attention_weights = last_layer[0, 0, 0, :].float().cpu().numpy()  # [1, 32, 1, seq_len] -> [seq_len]

# Reshape the attention weights into a 2D matrix (for visualization)
# Assuming the tokens are in a square-like arrangement; truncate to a perfect square
size = int(np.sqrt(attention_weights.shape[0]))
attention_map = attention_weights[: size * size].reshape(size, size)

# Plot the attention map
plt.imshow(attention_map, cmap="viridis")
plt.colorbar()
plt.title("Attention Map")
plt.show()

Explanation:

  1. Accessing Attention Weights: output.attentions is a tuple with one entry per generation step. The last element (output.attentions[-1]) holds the attention weights computed while generating the last token, as a tuple of per-layer tensors; taking its last entry gives the last decoder layer.
  2. Reshaping: The attention weights for one head form a vector over all key positions. The vector is truncated to a perfect square and reshaped into a 2D matrix for visualization; the size variable assumes the tokens are arranged in a roughly square grid.
  3. Plotting: The attention map is plotted using matplotlib.

This approach gives you a 2D visualization of how the last generated token attends to the input tokens. If you want to visualize attention for a specific part of the sequence, a different attention head, or a different layer, you can adjust the indices accordingly.
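
If you want a genuine (batch_size, num_heads, sequence_length, sequence_length) map, as in the docstring you quoted, you can also look at the first element of output.attentions. With the usual KV-cache generation loop, that first element should correspond to the prefill pass over the whole prompt (text plus image-patch tokens), so its query dimension is the full sequence length. Here is a minimal sketch, assuming the same output object as above and a transformers version that returns attentions for the prefill step (if they come back empty on your setup, loading the model with attn_implementation="eager" may help):

# The first generation step is the prefill over the full expanded prompt,
# so each per-layer tensor has shape [batch, num_heads, prompt_len, prompt_len]
prefill_attns = output.attentions[0]   # tuple: one tensor per decoder layer
full_attn = prefill_attns[-1]          # last decoder layer
print(full_attn.shape)                 # e.g. [1, 32, prompt_len, prompt_len]

# Average over heads and plot the full token-to-token map
full_map = full_attn[0].float().mean(dim=0).cpu().numpy()
plt.imshow(full_map, cmap="viridis")
plt.colorbar()
plt.title("Prefill attention (last layer, head-averaged)")
plt.show()

Keep in mind that storing attentions for every layer and step of a 2000+ token sequence can take several gigabytes of GPU memory, so you may want to reduce max_new_tokens while experimenting.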

Why Is the Output Shape [1, 32, 1, 2353]?

The model processes an input sequence with both text and visual tokens (LLaVA-NeXT expands the single image placeholder into many image-patch tokens), which yields a long sequence of about 2353 tokens. During generation the model uses a KV cache, so after the prefill pass each step feeds only the newly generated token as the query while attending to all previous keys. The tensors you printed come from such decoding steps, which is why they have 32 heads, a query dimension of 1, and a key dimension covering the whole sequence: [1, 32, 1, 2353].
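
As a quick sanity check (a sketch using the same output object; the exact lengths depend on your prompt and image), you can print the query and key lengths per generation step. Only the first (prefill) step has a full 2D query axis; every later step feeds just the newly generated token and relies on the KV cache:

# Shape pattern across generation steps (layer 0 shown; all layers match)
for step, step_attns in enumerate(output.attentions):
    q_len = step_attns[0].shape[2]   # query length for this step
    k_len = step_attns[0].shape[3]   # key length (past tokens + current)
    print(f"step {step}: query_len={q_len}, key_len={k_len}")
# step 0: query_len == key_len == full prompt length (text + image tokens)
# step 1+: query_len == 1, key_len grows by one per generated token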