Getting zero gradients for image patch embeddings when implementing GRADCAM for ViLT

I am trying to implement GRADCAM for ViLT (specifically for ViltForQuestionAnswering).

I am taking gradients and activations from the last layernorm layer in the model:

  (vilt): ViltModel(
    (embeddings): ViltEmbeddings(
      (text_embeddings): TextEmbeddings(
        (word_embeddings): Embedding(30522, 768)
        (position_embeddings): Embedding(40, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.0, inplace=False)
      (patch_embeddings): ViltPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(32, 32), stride=(32, 32))
      (token_type_embeddings): Embedding(2, 768)
      (dropout): Dropout(p=0.0, inplace=False)
    (encoder): ViltEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViltLayer(
          (attention): ViltAttention(
            (attention): ViltSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            (output): ViltSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
          (intermediate): ViltIntermediate(
            (dense): Linear(in_features=768, out_features=3072, bias=True)
            (intermediate_act_fn): GELUActivation()
          (output): ViltOutput(
            (dense): Linear(in_features=3072, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          (layernorm_before): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (layernorm_after): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (layernorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (pooler): ViltPooler(
      (dense): Linear(in_features=768, out_features=768, bias=True)
      (activation): Tanh()
  (classifier): Sequential(
    (0): Linear(in_features=768, out_features=1536, bias=True)
    (1): LayerNorm((1536,), eps=1e-05, elementwise_affine=True)
    (2): GELU(approximate='none')
    (3): Linear(in_features=1536, out_features=3129, bias=True)

That is, the (layernorm) module that sits just before the final pooler module.
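A minimal sketch of one way to capture activations and gradients from a single module with PyTorch hooks. The `HookCapture` class and the toy `nn.Sequential` model are hypothetical stand-ins for illustration, not the code used in this question; with the model printed above, the hooked module would presumably be `model.vilt.layernorm`.

```python
import torch
from torch import nn


class HookCapture:
    """Hypothetical helper: stores a module's output and the gradient w.r.t. that output."""

    def __init__(self, module: nn.Module):
        self.activations = None
        self.gradients = None
        self._fwd = module.register_forward_hook(self._save_activation)
        self._bwd = module.register_full_backward_hook(self._save_gradient)

    def _save_activation(self, module, inputs, output):
        # Called during the forward pass with the module's output tensor.
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        # grad_output[0] is the gradient of the score w.r.t. the module's output.
        self.gradients = grad_output[0].detach()

    def remove(self):
        self._fwd.remove()
        self._bwd.remove()


torch.manual_seed(0)
# Toy stand-in model; every token contributes to the score here.
toy = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), nn.Linear(8, 3))
capture = HookCapture(toy[1])  # hook the LayerNorm

x = torch.randn(2, 5, 8)       # [batch, tokens, hidden]
logits = toy(x)                # [batch, tokens, classes]
score = logits[:, :, 0].sum()  # scalar score to differentiate
score.backward()
capture.remove()

hook_shapes = (tuple(capture.activations.shape), tuple(capture.gradients.shape))
print(hook_shapes)
```

Both captured tensors have the same shape as the hooked module's output, so they can be indexed per token in the same way.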

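I get a tensor of shape [batch_sz, embeddings, 768], where embeddings = text_embeddings + image_embeddings (the number of text embeddings plus the number of image embeddings). When I look at the gradients, only the first position (index 0) along the embeddings dimension has non-zero values; every other position is all zeros: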

print(f"Non-zero gradients for the first embedding: {gradients[0, 0, :].count_nonzero()}")
Non-zero gradients for the first embedding: 768

print(f"Non-zero gradients for rest of the embeddings: {gradients[0, 1:, :].count_nonzero()}")
Non-zero gradients for rest of the embeddings: 0
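The same pattern can be reproduced on a toy model. The sketch below assumes the pooler reads only the first token of the sequence, as BERT-style poolers typically do; `ToyPooler` is a hypothetical stand-in, and whether the installed ViltPooler behaves this way should be checked against its source.

```python
import torch
from torch import nn


class ToyPooler(nn.Module):
    """Hypothetical stand-in for a BERT-style pooler (assumption: first token only)."""

    def __init__(self, hidden: int):
        super().__init__()
        self.dense = nn.Linear(hidden, hidden)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        first_token = hidden_states[:, 0]  # all other tokens are ignored
        return self.activation(self.dense(first_token))


torch.manual_seed(0)
hidden = 16
layernorm = nn.LayerNorm(hidden)
pooler = ToyPooler(hidden)
classifier = nn.Linear(hidden, 4)

tokens = torch.randn(1, 6, hidden)   # [batch, tokens, hidden]
normed = layernorm(tokens)
normed.retain_grad()                 # keep the gradient of this non-leaf tensor

logits = classifier(pooler(normed))  # [batch, classes]
target = logits[0].argmax()
logits[0, target].backward()

gradients = normed.grad
first_nonzero = int(gradients[0, 0, :].count_nonzero())
rest_nonzero = int(gradients[0, 1:, :].count_nonzero())
print(first_nonzero, rest_nonzero)
```

Under that assumption, only position 0 of the layernorm output feeds the classifier, so every other position receives an exactly zero gradient at this layer. Gradients taken at a point before the last self-attention would not be restricted in this way, since attention mixes information across tokens.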

I need gradients from the image embedding to be able to create a saliency map by applying GRADCAM.
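For context, a rough sketch of how a GRADCAM-style map could be computed from patch-token activations and gradients. The function name `token_gradcam`, the grid size, and the random inputs are all assumptions for illustration; selecting which positions of the hooked tensor are image patches is left out.

```python
import torch


def token_gradcam(patch_acts, patch_grads, grid_hw):
    """Hypothetical helper. patch_acts, patch_grads: [num_patches, hidden]."""
    h, w = grid_hw
    weights = patch_grads.mean(dim=0)       # one importance weight per channel
    cam = torch.relu(patch_acts @ weights)  # weighted channel sum per patch
    cam = cam.reshape(h, w)                 # back to the patch grid
    peak = cam.max()
    if peak > 0:
        cam = cam / peak                    # scale to [0, 1]
    return cam


torch.manual_seed(0)
acts = torch.randn(12, 768)   # made-up activations for a 3 x 4 patch grid
grads = torch.randn(12, 768)  # made-up gradients of the same shape

saliency = token_gradcam(acts, grads, (3, 4))
zero_map = token_gradcam(acts, torch.zeros_like(grads), (3, 4))
print(saliency.shape, int(zero_map.count_nonzero()))
```

With all-zero patch gradients the channel weights are zero, so the resulting map is empty, which is why the zero gradients above block the saliency map.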

Can someone help me understand why I might be getting zero gradients for the image embeddings?