I want to compute the gradient of class probability and class logits with respect to inputs in transformer models.

I have read issues #8601 and #8747, as well as the related test in the repository, to develop the following function.

```
def get_attention_hidden_states_grad_logits(hf_model, inputs, class_index, layer_index):
    # hidden_states / attentions are only returned when these flags are set
    model_outputs = hf_model(**inputs, output_hidden_states=True, output_attentions=True)
    hidden_states = model_outputs.get("hidden_states", [])
    attentions = model_outputs.get("attentions", [])
    for hs in hidden_states:
        hs.retain_grad()
    for attn in attentions:
        attn.retain_grad()
    class_logits = model_outputs.get("logits")
    class_probabilities = class_logits.softmax(dim=-1)
    class_logits.flatten()[class_index].backward(retain_graph=True)
    hidden_states_grad = hidden_states[layer_index].grad
    attention_grads = attentions[layer_index].grad
    return hidden_states_grad, attention_grads, class_probabilities, class_logits
```

#### Question 1:

Is it necessary to call the `retain_grad` method on all of the hidden states and attentions? Since backpropagation applies the chain rule through every layer anyway, I think the following code should be enough (calling `retain_grad` only on the hidden states and attention of the first layer):

```
def get_attention_hidden_states_grad_logits(hf_model, inputs, class_index, layer_index):
    model_outputs = hf_model(**inputs, output_hidden_states=True, output_attentions=True)
    hidden_states = model_outputs.get("hidden_states", [])
    attentions = model_outputs.get("attentions", [])
    # retain gradients only for the first layer's outputs
    hidden_states[0].retain_grad()
    attentions[0].retain_grad()
    class_logits = model_outputs.get("logits")
    class_probabilities = class_logits.softmax(dim=-1)
    class_logits.flatten()[class_index].backward(retain_graph=True)
    hidden_states_grad = hidden_states[0].grad
    attention_grads = attentions[0].grad
    return hidden_states_grad, attention_grads, class_probabilities, class_logits
```
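For what it's worth, a tiny standalone sketch (pure PyTorch, hypothetical toy graph, not a transformer) suggests that `retain_grad` only populates `.grad` on the specific non-leaf tensor it is called on; other intermediates on the same chain stay at `None` even though the chain rule passes through them:

```python
import torch

x = torch.randn(3, requires_grad=True)  # leaf tensor
h0 = x * 2          # intermediate (non-leaf), stand-in for a first-layer output
h1 = h0.tanh()      # intermediate (non-leaf), stand-in for a later layer
h0.retain_grad()    # retain the gradient only for the first intermediate
out = h1.sum()
out.backward(retain_graph=True)

print(h0.grad is None)  # False: retain_grad() was called on h0
print(h1.grad is None)  # True: no retain_grad(), so its gradient is discarded
print(x.grad is None)   # False: leaves always keep their .grad
```

So retaining only the first layer is enough *if* you only ever read the first layer's `.grad`; if you index an arbitrary layer, that layer's tensor must have `retain_grad` called on it before `backward`.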

#### Question 2:

I'm wondering whether it's possible to calculate the gradient of the class probabilities as well. Specifically, I'd like to know if the following code, which backpropagates from the class probabilities instead of the class logits, can be used to evaluate the gradient of the class probabilities with respect to the attentions and hidden states.

```
def get_attention_hidden_states_grad_probs(hf_model, inputs, class_index, layer_index):
    model_outputs = hf_model(**inputs, output_hidden_states=True, output_attentions=True)
    hidden_states = model_outputs.get("hidden_states", [])
    attentions = model_outputs.get("attentions", [])
    for hs in hidden_states:
        hs.retain_grad()
    for attn in attentions:
        attn.retain_grad()
    class_logits = model_outputs.get("logits")
    class_probabilities = class_logits.softmax(dim=-1)
    # backpropagate from the probability instead of the logit
    class_probabilities.flatten()[class_index].backward(retain_graph=True)
    hidden_states_grad = hidden_states[layer_index].grad
    attention_grads = attentions[layer_index].grad
    return hidden_states_grad, attention_grads, class_probabilities, class_logits
```

The gradients of the attentions and hidden states do change, which surprises me: since the attentions and hidden states are intermediate variables, I expected their gradients with respect to the inputs not to depend on which output variable (logits versus probabilities) I backpropagate from.
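A minimal sketch (toy linear model, hypothetical names, not the transformers API) suggests this change is expected: backpropagating from a probability inserts the softmax Jacobian into the chain rule, so the gradient flowing into every earlier tensor is rescaled and mixed across classes relative to backpropagating from the raw logit:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, requires_grad=True)  # stand-in for an intermediate tensor
W = torch.randn(3, 4)
logits = W @ x

# gradient of logit 0 w.r.t. x
g_logit = torch.autograd.grad(logits[0], x, retain_graph=True)[0]

# gradient of probability 0 w.r.t. x
probs = logits.softmax(dim=-1)
g_prob = torch.autograd.grad(probs[0], x, retain_graph=True)[0]

# d p_0 / dx = sum_j (dp_0 / dl_j) * (dl_j / dx): the softmax Jacobian row
# rescales and mixes the per-logit gradients, so the two results differ.
print(torch.allclose(g_logit, g_prob))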

#### Question 3:

It appears that the following code can also evaluate the gradients of the outputs (class probabilities or class logits) with respect to the attentions and hidden states. Can you please confirm whether this is correct?

```
torch.autograd.grad(model_outputs.logits.flatten()[0], model_outputs.attentions[0], create_graph=True)
torch.autograd.grad(model_outputs.logits.flatten()[0], model_outputs.hidden_states[0], create_graph=True)
torch.autograd.grad(model_outputs.logits.softmax(dim=-1).flatten()[0], model_outputs.attentions[0], create_graph=True)
torch.autograd.grad(model_outputs.logits.softmax(dim=-1).flatten()[0], model_outputs.hidden_states[0], create_graph=True)
```

@joeddav @patrickvonplaten

Could you please confirm whether I have made any errors? It would be tremendously helpful if you could provide the correct code to compute the gradients of the attentions and hidden states. This is an essential part of my project, and I greatly appreciate your assistance.

#### Side Note:

**I know that the gradients of the attentions and hidden states do not change when we change the `class_index`; this is indeed for another purpose and future tasks.**