Now I call my Python class to create the computation graph:
ir = ToyModel(model, (input_ids, attention_mask), dynamic_shapes=dynamic_shapes)
io_data = ir.predict(input_ids, attention_mask)
ir.evaluation()
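For reference, the inputs passed to ToyModel are prepared roughly as follows (a minimal sketch, assuming a Hugging Face Gemma3 checkpoint and torch.export-style dynamic shapes; the checkpoint name and the exact dynamic_shapes layout are assumptions, since they depend on how ToyModel invokes the exporter):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-3-1b-it"  # assumed checkpoint; any Gemma3 variant is prepared the same way
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32).eval()

enc = tokenizer("Hi", return_tensors="pt")
input_ids, attention_mask = enc["input_ids"], enc["attention_mask"]

# mark the sequence dimension (dim 1) as dynamic for torch.export
seq = torch.export.Dim("seq", min=1, max=512)
dynamic_shapes = {"input_ids": {1: seq}, "attention_mask": {1: seq}}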
The function inside the model class that computes self-attention is:
def calculate_self_attention(Q, K, V, masked_fill=None, scale=None, epsilon=1e-9):
    """
    Args:
        Q, K, V: [B, H, T_q, D] or [B, H, T_k, D]
        masked_fill: Optional additive mask of shape [B, 1, T_q, T_k] or [B, H, T_q, T_k]
        scale: Optional scaling factor (default: sqrt(D))
        epsilon: Small constant for numerical stability
    """
    print("Q dtype torch:", Q.dtype)
    print("K dtype torch:", K.dtype)
    print("V dtype torch:", V.dtype)

    B, H, T_q, D = Q.shape
    T_k = K.shape[2]  # number of key tokens
    scale = scale or np.sqrt(D)
    log(f"Q: {np.sum(Q):.4f}, K: {np.sum(K):.4f}, V: {np.sum(V):.4f}")

    # Step 1: Raw attention logits
    QK_output = np.matmul(Q, K.transpose(0, 1, 3, 2))  # [B, H, T_q, T_k]
    logits_unmasked = QK_output / scale
    print(f"QK_output--- shape: {QK_output.shape}, value: {np.sum(QK_output):.4f}")
    print(f"logits_unmasked--- shape: {logits_unmasked.shape}, value: {np.sum(logits_unmasked):.4f}")

    ###########################################################################
    def _nz_stats(name, arr, tol=1e-12):
        total = arr.size
        zeros = np.count_nonzero(np.abs(arr) < tol)
        nonzeros = total - zeros
        pct = (nonzeros / total) * 100
        print(f"{name}: nonzeros={nonzeros} ({pct:.2f}%), zeros={zeros}")

    # Debug: non-zero stats
    _nz_stats("Q: ", Q)
    _nz_stats("K: ", K)
    _nz_stats("V: ", V)
    _nz_stats("QK_output: ", QK_output)
    ###########################################################################

    # Step 2: Softmax over unmasked logits (for debugging or interpretability)
    A = np.exp(logits_unmasked - np.max(logits_unmasked, axis=-1, keepdims=True))
    A = A / (np.sum(A, axis=-1, keepdims=True) + epsilon)
    log(f"A (unmasked attention weights) --- shape: {A.shape}, value: {np.sum(A):.4f}")

    # Step 3: Apply additive attention mask (optional)
    masked_fill = None  # NOTE: this overwrites the masked_fill argument, so the branch below never runs
    if masked_fill is not None:
        logits_masked = logits_unmasked + masked_fill  # [B, H, T_q, T_k] + [B, 1, T_q, T_k]
        log(f"masked_fill--- minimum: {np.min(masked_fill)}, maximum: {np.max(masked_fill)}")
    else:
        logits_masked = logits_unmasked.copy()
    log(f"logits_masked --- shape: {logits_masked.shape}, value: {np.sum(logits_masked):.4f}")

    # Step 4: Softmax over masked logits
    A_masked = np.exp(logits_masked - np.max(logits_masked, axis=-1, keepdims=True))
    A_masked = A_masked / (np.sum(A_masked, axis=-1, keepdims=True) + epsilon)
    log(f"A_masked (masked attention weights)--- shape: {A_masked.shape}, value: {np.sum(A_masked):.4f}")

    # Step 5: Compute attention output using masked weights
    attention_output = np.matmul(A_masked, V)  # [B, H, T_q, D]
    log(f"attention_output (using A_masked)--- shape: {attention_output.shape}, value: {np.sum(attention_output):.4f}")
These are the output logs I get for the Gemma3 model:
ToyModel: node='scaled_dot_product_attention_25', layer='Attention', func='scaled_dot_product_attention', parents='['clone_101', 'clone_102', 'clone_103', 'slice_494']', children='['transpose_105']'
[DEBUG] number of inputs: 4
[DEBUG] idx: 0, item: torch.Size([1, 4, 2, 256])
[DEBUG] idx: 1, item: torch.Size([1, 4, 2, 256])
[DEBUG] idx: 2, item: torch.Size([1, 4, 2, 256])
[DEBUG] idx: 3, item: torch.Size([1, 1, 2, 2])
[DEBUG] Q: (1, 4, 2, 256), K: (1, 4, 2, 256), V: (1, 4, 2, 256), masked_fill: (1, 1, 2, 2)
Q dtype torch: float32
K dtype torch: float32
V dtype torch: float32
[DEBUG] Q: 0.0000, K: -0.0000, V: 0.0000
QK_output--- shape: (1, 4, 2, 2), value: 0.0000
logits_unmasked--- shape: (1, 4, 2, 2), value: 0.0000
Q: : nonzeros=1013 (49.46%), zeros=1035
K: : nonzeros=1024 (50.00%), zeros=1024
V: : nonzeros=1024 (50.00%), zeros=1024
QK_output: : nonzeros=0 (0.00%), zeros=16
[DEBUG] A (unmasked attention weights) --- shape: (1, 4, 2, 2), value: 8.0000
[DEBUG] logits_masked --- shape: (1, 4, 2, 2), value: 0.0000
[DEBUG] A_masked (masked attention weights)--- shape: (1, 4, 2, 2), value: 8.0000
[DEBUG] attention_output (using A_masked)--- shape: (1, 4, 2, 256), value: 0.0000
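Since Q and K are each only about 50% nonzero while every entry of Q·Kᵀ comes out exactly zero, one quick follow-up is to check whether the nonzero feature positions of Q and K ever coincide; if they never overlap, every dot product is zero by construction regardless of the values (a minimal diagnostic sketch over the same captured arrays, using the same tolerance as _nz_stats):

import numpy as np

def check_qk_overlap(Q, K, tol=1e-12):
    """Count, per (query, key) pair, how many feature dims are nonzero in both Q and K."""
    q_nz = (np.abs(Q) > tol).astype(np.int64)  # [B, H, T_q, D]
    k_nz = (np.abs(K) > tol).astype(np.int64)  # [B, H, T_k, D]
    # overlap[b, h, i, j] = number of dims where Q[b, h, i, :] and K[b, h, j, :] are both nonzero
    overlap = np.einsum("bhid,bhjd->bhij", q_nz, k_nz)
    print(f"overlapping nonzero dims per (query, key) pair: min={overlap.min()}, max={overlap.max()}")

check_qk_overlap(Q, K)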