Getting zero attention in attention_module of Gemma3

I am using a Gemma-3 model for an experiment.

First, I use torch.export to capture the computational graph. Then I extract the inputs, weights, biases (if present), and outputs of each layer.

Note that the nodes of this graph are basic PyTorch operations.

I am getting QK_output as 0. The same code works for BERT and Llama family models.

What is wrong here?

My code:

import os
import torch
import torch.nn as nn
from torch.export import Dim
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import login

model_id = "google/gemma-3-1b-it"
auth_token = "[Here, place your Huggingface Authentication Token]"


class GemmaWrapper(nn.Module):
    def __init__(self, model_id, token):
        super().init()
        self.model = AutoModelForCausalLM.from_pretrained(model_id,torch_dtype=torch.float32,token=token).eval()

    def forward(self, input_ids, attention_mask):
        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False).logits



model = GemmaWrapper(model_id, auth_token)
tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
tokenizer.pad_token = tokenizer.eos_token

sentences = ["Hello"]

tokens = tokenizer(sentences, return_tensors="pt", padding=True, truncation=True)
input_ids = tokens["input_ids"]
attention_mask = tokens["attention_mask"]

if len(sentences) > 1:
    batch_dim = Dim("batch", min=1, max=len(sentences))
else:
    batch_dim = 1  # Static dimension

seq_dim = Dim("seq", min=1, max=input_ids.shape[1])

dynamic_shapes = {
    "input_ids": {0: batch_dim, 1: seq_dim},
    "attention_mask": {0: batch_dim, 1: seq_dim},
}

Now I call my Python class to create the computation graph.

ir = ToyModel(model, (input_ids, attention_mask), dynamic_shapes=dynamic_shapes)

io_data = ir.predict(input_ids, attention_mask)

ir.evaluation()


The function inside the model class that computes attention is:

def calculate_self_attention(Q, K, V, masked_fill=None, scale=None, epsilon=1e-9):
    """
    Args:
        Q, K, V: [B, H, T_q, D] or [B, H, T_k, D]
        masked_fill: Optional additive mask of shape [B, 1, T_q, T_k] or [B, H, T_q, T_k]
        scale: Optional scaling factor (default: sqrt(D))
        epsilon: Small constant for numerical stability 
    """
    print("Q dtype torch:", Q.dtype)
    print("K dtype torch:", K.dtype)
    print("V dtype torch:", V.dtype) 

    B, H, T_q, D = Q.shape
    T_k = K.shape[2]  # number of key tokens
    scale = scale or np.sqrt(D)

    log(f"Q: {np.sum(Q):.4f}, K: {np.sum(K):.4f}, V: {np.sum(V):.4f}")

    # Step 1: Raw attention logits
    QK_output = np.matmul(Q, K.transpose(0, 1, 3, 2))  # [B, H, T_q, T_k]
    logits_unmasked = QK_output / scale 
    print(f"QK_output---  shape: {QK_output.shape},  value: {np.sum(QK_output):.4f}")
    print(f"logits_unmasked---  shape: {logits_unmasked.shape},  value: {np.sum(logits_unmasked):.4f}")

    ###########################################################################
    def _nz_stats(name, arr, tol=1e-12):
        total = arr.size
        zeros = np.count_nonzero(np.abs(arr) < tol)
        nonzeros = total - zeros
        pct = (nonzeros / total) * 100
        print(f"{name}: nonzeros={nonzeros} ({pct:.2f}%), zeros={zeros}")

    # Debug: non-zero stats
    _nz_stats("Q: ", Q)
    _nz_stats("K: ", K)
    _nz_stats("V: ", V)
    _nz_stats("QK_output: ", QK_output)
    ###########################################################################

    # Step 2: Softmax over unmasked logits (for debugging or interpretability)
    A = np.exp(logits_unmasked - np.max(logits_unmasked, axis=-1, keepdims=True))
    A = A / (np.sum(A, axis=-1, keepdims=True) + epsilon) 
    log(f"A (unmasked attention weights) ---  shape: {A.shape}, value: {np.sum(A):.4f}")

    # Step 3: Apply additive attention mask (optional)
    masked_fill = None 
    if masked_fill is not None:
        logits_masked = logits_unmasked + masked_fill  # [B, H, T, T] + [B, 1, T, T]
        log(f"masked_fill--- minimum: {np.min(masked_fill)}, maximum: {np.max(masked_fill)}") 
    else:
        logits_masked = logits_unmasked.copy()
    log(f"logits_masked --- shape: {logits_masked.shape},  value: {np.sum(logits_masked):.4f}")

    # Step 4: Softmax over masked logits
    A_masked = np.exp(logits_masked - np.max(logits_masked, axis=-1, keepdims=True))
    A_masked = A_masked / (np.sum(A_masked, axis=-1, keepdims=True) + epsilon) 
    log(f"A_masked (masked attention weights)---  shape: {A_masked.shape},  value: {np.sum(A_masked):.4f}")

    # Step 5: Compute attention output using masked weights
    attention_output = np.matmul(A_masked, V)  # [B, H, T_q, D] 
    log(f"attention_output (using A_masked)---  shape: {attention_output.shape},  value: {np.sum(attention_output):.4f}")

Output logs I am getting for Gemma3 model:

ToyModel: node='scaled_dot_product_attention_25', layer='Attention', func='scaled_dot_product_attention', parents='['clone_101', 'clone_102', 'clone_103', 'slice_494']', children='['transpose_105']'
[DEBUG] number of inputs: 4
[DEBUG] idx: 0, item: torch.Size([1, 4, 2, 256])
[DEBUG] idx: 1, item: torch.Size([1, 4, 2, 256])
[DEBUG] idx: 2, item: torch.Size([1, 4, 2, 256])
[DEBUG] idx: 3, item: torch.Size([1, 1, 2, 2])
[DEBUG] Q: (1, 4, 2, 256), K: (1, 4, 2, 256), V: (1, 4, 2, 256), masked_fill: (1, 1, 2, 2)
Q dtype torch: float32
K dtype torch: float32
V dtype torch: float32
[DEBUG] Q: 0.0000, K: -0.0000, V: 0.0000
QK_output---  shape: (1, 4, 2, 2),  value: 0.0000
logits_unmasked---  shape: (1, 4, 2, 2),  value: 0.0000
Q: : nonzeros=1013 (49.46%), zeros=1035
K: : nonzeros=1024 (50.00%), zeros=1024
V: : nonzeros=1024 (50.00%), zeros=1024
QK_output: : nonzeros=0 (0.00%), zeros=16
[DEBUG] A (unmasked attention weights) ---  shape: (1, 4, 2, 2), value: 8.0000
[DEBUG] logits_masked --- shape: (1, 4, 2, 2),  value: 0.0000
[DEBUG] A_masked (masked attention weights)---  shape: (1, 4, 2, 2),  value: 8.0000
[DEBUG] attention_output (using A_masked)---  shape: (1, 4, 2, 256),  value: 0.0000
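
In these logs roughly half of the entries of Q and K are zero, yet all 16 entries of QK_output are zero. One arrangement that would produce exactly this pattern (a hypothesis to test, not a confirmed diagnosis) is that Q and K are nonzero at disjoint positions along the head dimension, so every term of each dot product vanishes. The sketch below checks for that on synthetic arrays; `support_overlap` is a hypothetical helper name, and the Q and K here are made-up data rather than the captured tensors.

```python
import numpy as np

def support_overlap(Q, K, tol=1e-12):
    """Per (batch, head), count head-dim positions where both Q and K are nonzero.

    Q: [B, H, T_q, D], K: [B, H, T_k, D]. A count of 0 means every Q.K dot
    product is exactly zero, whatever the individual values are.
    """
    q_active = (np.abs(Q) > tol).any(axis=2)   # [B, H, D]
    k_active = (np.abs(K) > tol).any(axis=2)   # [B, H, D]
    return (q_active & k_active).sum(axis=-1)  # [B, H]

# Synthetic data: Q lives in the first half of D, K in the second half.
rng = np.random.default_rng(0)
Q = np.zeros((1, 4, 2, 256), dtype=np.float32)
K = np.zeros((1, 4, 2, 256), dtype=np.float32)
Q[..., :128] = rng.standard_normal((1, 4, 2, 128))
K[..., 128:] = rng.standard_normal((1, 4, 2, 128))

scores = np.matmul(Q, K.transpose(0, 1, 3, 2))  # [B, H, T_q, T_k]
print("overlap per head:", support_overlap(Q, K))
print("max |QK|:", np.abs(scores).max())
```

If the same check on the captured Gemma 3 tensors reports zero overlap, the problem is in how Q and K are extracted from the graph rather than in the matmul itself.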


Output logs I am getting for LLaMA3.2-1B model:

ToyModel: node='scaled_dot_product_attention_15', layer='Attention', func='scaled_dot_product_attention', parents='['clone_63', '_unsafe_view_30', '_unsafe_view_31', 'slice_311']', children='['transpose_64']'
[DEBUG] number of inputs: 4
[DEBUG] idx: 0, item: torch.Size([1, 32, 2, 64])
[DEBUG] idx: 1, item: torch.Size([1, 32, 2, 64])
[DEBUG] idx: 2, item: torch.Size([1, 32, 2, 64])
[DEBUG] idx: 3, item: torch.Size([1, 1, 2, 2])
[DEBUG] Q: (1, 32, 2, 64), K: (1, 32, 2, 64), V: (1, 32, 2, 64), masked_fill: (1, 1, 2, 2)
Q dtype torch: float32
K dtype torch: float32
V dtype torch: float32
[DEBUG] Q: 0.5828, K: 0.7185, V: -2.3097
QK_output---  shape: (1, 32, 2, 2),  value: 0.1283
logits_unmasked---  shape: (1, 32, 2, 2),  value: 0.0160
Q: : nonzeros=4096 (100.00%), zeros=0
K: : nonzeros=4096 (100.00%), zeros=0
V: : nonzeros=4096 (100.00%), zeros=0
QK_output: : nonzeros=128 (100.00%), zeros=0
[DEBUG] A (unmasked attention weights) ---  shape: (1, 32, 2, 2), value: 64.0000
[DEBUG] logits_masked --- shape: (1, 32, 2, 2),  value: 0.0160
[DEBUG] A_masked (masked attention weights)---  shape: (1, 32, 2, 2),  value: 64.0000
[DEBUG] attention_output (using A_masked)---  shape: (1, 32, 2, 64),  value: -2.3097 

Hi,

the most likely source of your issue is a small typo in your GemmaWrapper class. In Python, the class constructor must be named __init__ (with double underscores on both sides), not init.

Because of this typo, your self.model = ... line is never actually executed, and your GemmaWrapper instance doesn’t contain the Gemma model.
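
A minimal plain-Python illustration of that point (the class names `Broken` and `Fixed` are made up for the example, and no model is loaded): a method named init is an ordinary method that Python never calls during construction, so the attribute it would set does not exist.

```python
class Broken:
    def init(self):          # ordinary method, not the constructor
        self.model = "loaded"

class Fixed:
    def __init__(self):      # real constructor, runs on instantiation
        self.model = "loaded"

print(hasattr(Broken(), "model"))
print(hasattr(Fixed(), "model"))
```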


@AerisCodex This was a typo in the pasted code here.

Original code contains:

class GemmaWrapper(nn.Module):
    def __init__(self, model_id, token):
        super().__init__()
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float32,
            token=token
        ).eval() 
    
    def forward(self, input_ids, attention_mask):
        return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=False).logits

Hmm… When using SDPA or FlashAttention, attention may not be returned…?

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"
#m_id = "google/gemma-3-1b-it"
m_id = "unsloth/gemma-3-1b-it-bnb-4bit"
#unsloth/gemma-3-1b-it # same result in my environment
attn_implementation = "eager"
#attn_implementation = "sdpa"
tok = AutoTokenizer.from_pretrained(m_id)
model = AutoModelForCausalLM.from_pretrained(m_id, device_map="auto", attn_implementation=attn_implementation).eval()

torch.backends.cuda.enable_flash_sdp(False)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(True)

inp = tok(["Hello"], return_tensors="pt", padding=True).to(device)
with torch.no_grad():
    out = model(**inp, output_attentions=True)
print("attn_implementation:", attn_implementation)
if out.attentions: print("attn sum:", sum(a.abs().sum().item() for a in out.attentions if a is not None))
#attn_implementation: sdpa
#attn sum: 0
#attn_implementation: eager
#attn sum: 208.0

@John6666

My issue is with my custom implementation of SDPA, the calculate_self_attention function from my first post.


Is this implementation correct?
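
For comparison, here is a compact NumPy sketch of scaled dot-product attention, written under a few assumptions. It follows the convention of torch.nn.functional.scaled_dot_product_attention, where `scale` is a multiplier that defaults to 1/sqrt(D), whereas the posted function divides by `scale`. It accepts either a boolean mask (True means "may attend") or an additive float mask, and it returns its result. The posted function sets masked_fill = None just before the check, so its mask is never applied, and it has no return statement. None of these differences would make QK_output zero on its own, since that product is computed before scaling and masking. The name `sdpa_reference` is made up for this example.

```python
import numpy as np

def sdpa_reference(Q, K, V, attn_mask=None, scale=None):
    """NumPy sketch of scaled dot-product attention.

    Q: [B, H, T_q, D]; K, V: [B, H, T_k, D].
    attn_mask: None, a boolean mask (True = may attend), or an additive
               float mask, broadcastable to [B, H, T_q, T_k].
    scale: multiplier applied to Q @ K^T (defaults to 1 / sqrt(D)).
    """
    D = Q.shape[-1]
    scale = 1.0 / np.sqrt(D) if scale is None else scale
    logits = np.matmul(Q, np.swapaxes(K, -1, -2)) * scale  # [B, H, T_q, T_k]

    if attn_mask is not None:
        if attn_mask.dtype == np.bool_:
            # Blocked positions get -inf; a fully blocked row would yield NaN.
            logits = np.where(attn_mask, logits, -np.inf)
        else:
            logits = logits + attn_mask

    # Numerically stable softmax over the key axis.
    logits = logits - np.max(logits, axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights = weights / np.sum(weights, axis=-1, keepdims=True)
    return np.matmul(weights, V), weights

rng = np.random.default_rng(0)
Q = rng.standard_normal((1, 4, 2, 8)).astype(np.float32)
K = rng.standard_normal((1, 4, 2, 8)).astype(np.float32)
V = rng.standard_normal((1, 4, 2, 8)).astype(np.float32)
causal = np.tril(np.ones((2, 2), dtype=bool))[None, None]  # [1, 1, 2, 2]

out, weights = sdpa_reference(Q, K, V, attn_mask=causal)
print(out.shape, weights.sum(axis=-1))
```

With the causal mask, the first query can only see the first key, so the first output row equals the first row of V. That makes a quick sanity check when comparing against the values captured from the traced graph.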


Could this be the reason why QKV isn’t returned in your implementation?

The second main architectural improvement is an increase in context size to 128K tokens, without reducing performance. A challenge with long context is the memory explosion of the KV cache during inference. To reduce this issue, we interleave multiple local layers between each global layer, and assign a smaller span of only 1024 tokens to the local layers. Therefore, only the global layers attend to long context, and we have 1 global for every 5 local layers.
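
To make the local/global split in that passage concrete, here is a small NumPy sketch of a causal mask versus a sliding-window causal mask, plus a layer schedule with one global layer after every five local ones. The function names, the exact window boundary (a query sees itself and the previous window - 1 tokens), and the position of the global layer inside each group of six are assumptions for illustration, not taken from the Gemma 3 code.

```python
import numpy as np

def causal_mask(seq_len):
    """True where query i may attend to key j (j <= i)."""
    return np.tril(np.ones((seq_len, seq_len), dtype=bool))

def sliding_window_mask(seq_len, window):
    """Causal mask restricted to the most recent `window` keys (assumed boundary)."""
    i = np.arange(seq_len)[:, None]
    j = np.arange(seq_len)[None, :]
    return (j <= i) & (i - j < window)

def is_global_layer(layer_idx, pattern=6):
    """Assumed schedule: layers 5, 11, 17, ... are global, the rest local."""
    return (layer_idx + 1) % pattern == 0

print(sliding_window_mask(5, 2).astype(int))
print([is_global_layer(i) for i in range(12)])
# For inputs shorter than the window, both masks coincide.
print(np.array_equal(sliding_window_mask(5, 1024), causal_mask(5)))
```

For a two-token input like the one in the logs above, a 1024-token window changes nothing, so the local layers see the same mask as the global ones.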

Or perhaps it is due to this…?

Both Gemma 2 and Gemma 3 utilize Grouped-Query Attention (GQA) with post-norm and pre-norm with RMSNorm. However, Gemma 3 gains both improved accuracy and faster processing speeds by adopting QK-norm in place of Gemma 2’s soft-capping mechanism.
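
As a rough picture of what QK-norm means here, the sketch below RMS-normalizes each query and key vector along the head dimension before the dot product. It is a simplified sketch: `rms_norm` is a made-up helper, the optional `gain` stands in for whatever learned scale the real model applies, and the eps value is an assumption.

```python
import numpy as np

def rms_norm(x, gain=None, eps=1e-6):
    """RMS-normalize the last axis of x, then apply an optional per-feature gain."""
    rms = np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True) + eps)
    y = x / rms
    return y if gain is None else y * gain

rng = np.random.default_rng(0)
# Deliberately large inputs to show that the norm controls the logit scale.
Q = rng.standard_normal((1, 4, 2, 256)).astype(np.float32) * 50.0
K = rng.standard_normal((1, 4, 2, 256)).astype(np.float32) * 50.0

Qn, Kn = rms_norm(Q), rms_norm(K)
scores = np.matmul(Qn, np.swapaxes(Kn, -1, -2)) / np.sqrt(Q.shape[-1])

print(np.mean(np.square(Qn), axis=-1))  # mean square per vector, close to 1
print(np.abs(scores).max())             # bounded by sqrt(D) without a gain
```

Without a gain, each normalized vector has norm close to sqrt(D), so the scaled logits cannot exceed sqrt(D) in magnitude. If the Q and K captured from the graph are taken before or after this normalization step, their values will differ accordingly.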

You chose a quantized model, i.e. "unsloth/gemma-3-1b-it-bnb-4bit". In that case, the attn sum is 0 for sdpa and 208.0 for eager.

I chose the "google/gemma-3-1b-it" model, and there I get 208.0 in both cases.


Also, the only issue is that QK_output becomes 0, i.e. QK_output = np.matmul(Q, K.transpose(0, 1, 3, 2))  # [B, H, T_q, T_k] is all zeros.

Also, this attention module is traced as a 'scaled_dot_product_attention' function with 4 input values.

[DEBUG] number of inputs: 4
[DEBUG] idx: 0, item: torch.Size([1, 4, 41, 256])
[DEBUG] idx: 1, item: torch.Size([1, 4, 41, 256])
[DEBUG] idx: 2, item: torch.Size([1, 4, 41, 256])
[DEBUG] idx: 3, item: torch.Size([1, 1, 41, 41])