Evaluation of the gradients of class probabilities and logits with respect to attentions and hidden states

I want to compute the gradients of the class probabilities and class logits with respect to the attentions and hidden states in transformer models.

I have read issues #8601 and #8747, as well as the related test in the repository, and used them to develop the following function.
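
For reference, here is a minimal sketch of the setup the functions below assume (the checkpoint name is just a placeholder; any sequence-classification model should behave the same way):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Placeholder checkpoint for illustration.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
hf_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
hf_model.eval()

inputs = tokenizer("An example sentence.", return_tensors="pt")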


def get_attention_hidden_states_grad_logits(hf_model, inputs, class_index, layer_index):
    # output_hidden_states/output_attentions must be enabled for these
    # fields to be populated in the output.
    model_outputs = hf_model(**inputs, output_hidden_states=True, output_attentions=True)

    hidden_states = model_outputs.hidden_states
    attentions = model_outputs.attentions

    # hidden_states and attentions are non-leaf tensors, so retain_grad
    # is needed before .grad can be read from them.
    for hidden_state in hidden_states:
        hidden_state.retain_grad()
    for attention in attentions:
        attention.retain_grad()

    class_logits = model_outputs.logits
    class_probabilities = class_logits.softmax(dim=-1)

    class_logits.flatten()[class_index].backward(retain_graph=True)

    hidden_states_grad = hidden_states[layer_index].grad
    attention_grads = attentions[layer_index].grad

    return hidden_states_grad, attention_grads, class_probabilities, class_logits
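
For context, this is how I call it, assuming the setup above (class_index=0 and layer_index=0 are arbitrary choices):

hidden_grad, attn_grad, probs, logits = get_attention_hidden_states_grad_logits(
    hf_model, inputs, class_index=0, layer_index=0
)
print(hidden_grad.shape)  # (batch, seq_len, hidden_size)
print(attn_grad.shape)    # (batch, num_heads, seq_len, seq_len)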

Question 1:

Is it necessary to call the retain_grad method on all of the hidden_states and attentions? Because of the chain rule, I think the following code would be enough (calling retain_grad only on the hidden states and attentions of the first layer):


def get_attention_hidden_states_grad_logits(hf_model, inputs, class_index, layer_index):
    model_outputs = hf_model(**inputs, output_hidden_states=True, output_attentions=True)

    hidden_states = model_outputs.hidden_states
    attentions = model_outputs.attentions

    # retain_grad only on the first layer's hidden states and attentions
    hidden_states[0].retain_grad()
    attentions[0].retain_grad()

    class_logits = model_outputs.logits
    class_probabilities = class_logits.softmax(dim=-1)

    class_logits.flatten()[class_index].backward(retain_graph=True)

    hidden_states_grad = hidden_states[layer_index].grad
    attention_grads = attentions[layer_index].grad

    return hidden_states_grad, attention_grads, class_probabilities, class_logits
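
To check this, I was planning to run something like the following after defining the second variant, to see whether gradients are still returned for a later layer (a hypothetical check, reusing the setup above):

hidden_grad, attn_grad, _, _ = get_attention_hidden_states_grad_logits(
    hf_model, inputs, class_index=0, layer_index=2
)
# If retain_grad on the first layer alone were enough, neither of these
# should be None.
print(hidden_grad is None, attn_grad is None)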

Question 2:

I'm wondering whether it's possible to calculate the gradient of the class probabilities. Specifically, I'd like to know if the following code, which backpropagates from a class probability instead of a class logit, can be used to evaluate the gradient of the class probabilities with respect to the attentions and hidden states.

def get_attention_hidden_states_grad_probs(hf_model, inputs, class_index, layer_index):
    model_outputs = hf_model(**inputs, output_hidden_states=True, output_attentions=True)

    hidden_states = model_outputs.hidden_states
    attentions = model_outputs.attentions

    for hidden_state in hidden_states:
        hidden_state.retain_grad()
    for attention in attentions:
        attention.retain_grad()

    class_logits = model_outputs.logits
    class_probabilities = class_logits.softmax(dim=-1)

    # backpropagate from a probability instead of a logit
    class_probabilities.flatten()[class_index].backward(retain_graph=True)

    hidden_states_grad = hidden_states[layer_index].grad
    attention_grads = attentions[layer_index].grad

    return hidden_states_grad, attention_grads, class_probabilities, class_logits

When I do this, the gradients of the attentions and hidden states change, which surprises me: since the attentions and hidden states are intermediate variables, I expected their gradients with respect to the inputs not to change when only the output variable changes.
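
To make my mental model concrete, here is a toy sketch, independent of the transformer, with a random linear head standing in for the classifier, of what I understand the chain rule through the softmax to be:

import torch

h = torch.randn(4, requires_grad=True)  # stand-in for a hidden state
W = torch.randn(3, 4)                   # stand-in for the classification head
z = W @ h                               # "logits"
p = z.softmax(dim=-1)                   # "probabilities"

(grad_logit,) = torch.autograd.grad(z[0], h, retain_graph=True)
(grad_prob,) = torch.autograd.grad(p[0], h, retain_graph=True)

# The two gradients differ, because backward from p[0] mixes all logits:
print(torch.allclose(grad_logit, grad_prob))  # expected: False

# Chain rule through the softmax:
#   d p[0] / d h = sum_j p[0] * ((j == 0) - p[j]) * d z[j] / d h
jacobian_row = (p[0] * ((torch.arange(3) == 0).float() - p)).detach()
print(torch.allclose(grad_prob, jacobian_row @ W))  # expected: True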

Question 3:

It appears that the following calls can also evaluate the gradients of the outputs (class probabilities or class logits) with respect to the attentions and hidden states. Could you please confirm whether this is correct?

torch.autograd.grad(model_outputs.logits.flatten()[0], model_outputs.attentions[0], create_graph=True)
torch.autograd.grad(model_outputs.logits.flatten()[0], model_outputs.hidden_states[0], create_graph=True)

torch.autograd.grad(model_outputs.logits.softmax(dim=-1).flatten()[0], model_outputs.attentions[0], create_graph=True)
torch.autograd.grad(model_outputs.logits.softmax(dim=-1).flatten()[0], model_outputs.hidden_states[0], create_graph=True)
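
And this is the comparison I would use to confirm that torch.autograd.grad agrees with the backward/retain_grad approach, assuming backward was called exactly once on the same model_outputs so that .grad has not accumulated over several calls. (I used retain_graph=True here since, as far as I understand, create_graph=True is only needed for higher-order derivatives.)

(grad_h0,) = torch.autograd.grad(
    model_outputs.logits.flatten()[0],
    model_outputs.hidden_states[0],
    retain_graph=True,
)
# Should match the .grad populated by the backward-based function above.
print(torch.allclose(grad_h0, model_outputs.hidden_states[0].grad))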

@joeddav @patrickvonplaten
Could you please confirm whether I have made any errors? It would be tremendously helpful if you could provide the correct code to compute the gradients of the attentions and hidden states. This is an essential part of my project, and I greatly appreciate your assistance.

Side Note:

I know that the gradients of the attentions and hidden states do not change when we change the class_index; the class_index argument is there for another purpose and future tasks.