Finding gradients in zero-shot learning

First off, the zero-shot module is amazing. It wraps up a lot of the boilerplate I'd been writing into a nice, succinct interface. That said, I'm having trouble getting the gradients of intermediate layers. Let's take an example:

from transformers import pipeline
import torch

model_name = 'facebook/bart-large-mnli'
nlp = pipeline("zero-shot-classification", model=model_name)

responses = ["I'm having a great day!!"]
hypothesis_template = 'This person feels {}'
candidate_labels = ['happy', 'sad']
nlp(responses, candidate_labels, hypothesis_template=hypothesis_template)

This works well! The output is:

{'sequence': "I'm having a great day!!",
 'labels': ['happy', 'sad'],
 'scores': [0.9989933371543884, 0.0010066736722365022]}

What I’d like to do, however, is look at the gradients of the input tokens to see which tokens are important. This is in contrast to looking at the attention heads (which is another viable tactic). Ripping apart the internals of the module, I can get the logits and embedding layers:

inputs = nlp._parse_and_tokenize(responses, candidate_labels, hypothesis_template)
predictions = nlp.model(**inputs, return_dict=True, output_hidden_states=True)
predictions['logits']

tensor([[-3.1864, -0.0714,  3.2625],
        [ 4.5919, -1.9473, -3.6376]], grad_fn=<AddmmBackward>)

This is expected: the label “happy” is index 0, and the entailment index for this model is 2, so the value of 3.2625 is an extremely strong entailment signal. The label “sad” is index 1 and the contradiction index is 0, so the value of 4.5919 is also the correct answer.
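As a sanity check, the pipeline's reported scores can be recovered from these logits by hand: in the default single-label setting, the zero-shot pipeline takes the entailment logit for each (premise, hypothesis) pair and softmaxes across the candidate labels. A minimal sketch using the logits above:

```python
import torch

# Raw NLI logits for the two (premise, hypothesis) pairs above:
# row 0 = "happy" hypothesis, row 1 = "sad" hypothesis
logits = torch.tensor([[-3.1864, -0.0714,  3.2625],
                       [ 4.5919, -1.9473, -3.6376]])

entailment_idx = 2  # entailment index for facebook/bart-large-mnli
entailment_logits = logits[:, entailment_idx]

# Softmax the entailment logits across candidate labels
scores = entailment_logits.softmax(dim=0)
print(scores)  # tensor([0.9990, 0.0010]) -- matches the pipeline output
```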

Great! Now I should be able to look at the first embedding layer and check out the gradient with respect to the happy entailment scalar:

layer = predictions['encoder_hidden_states'][0]
layer.retain_grad()
predictions['logits'][0][2].backward(retain_graph=True)

Unfortunately, layer.grad is None. I’ve tried almost everything I can think of, and now I’m a bit stuck. Thanks for the help!
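For reference, this is the pattern I expected to work: calling retain_grad on a non-leaf activation that sits on the path to the scalar being backpropped. A minimal pure-PyTorch sketch with a toy embedding and linear head (both made up for illustration, standing in for the model's input embeddings and classification head):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed = nn.Embedding(10, 4)  # toy stand-in for the input embeddings
head = nn.Linear(4, 3)       # toy stand-in for the classification head

token_ids = torch.tensor([[1, 2, 3]])
embeddings = embed(token_ids)
embeddings.retain_grad()     # non-leaf tensors need this to keep .grad

logits = head(embeddings).sum(dim=1)  # shape (1, 3)
logits[0, 2].backward()               # backprop the "entailment" scalar

print(embeddings.grad.shape)  # torch.Size([1, 3, 4]) -- one gradient per token
```

Here embeddings is on the path from the input to the scalar, so after backward its .grad is populated with a per-token gradient, which is exactly what I want from the BART encoder hidden states.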

I’ve reproduced this but I'm not sure I have a good answer – it looks like more of a Bart/PyTorch question rather than something specific to the zero-shot pipeline. Maybe @patrickvonplaten would have an idea?

Thanks for looking @joeddav! I’ve spent all day trying to figure this out, and I think the issue lies somewhere between the encoder layer and the BART model itself. Looking at the source code, calling forward encodes the input and passes that encoded representation on to the rest of the BART model. These encoder_hidden_states are disconnected from the main graph when I backprop. I’m wondering if the Hugging Face implementation decouples them somewhere for performance reasons (that would make sense; you shouldn’t normally need to backprop this far when training).

Hopefully @patrickvonplaten can weigh in with some insight. If not, I’ll try pytorch forums and SO.

@joeddav @patrickvonplaten I spent my hard-earned internet points on SO to find the problem:

It turns out that calling the transpose:

encoder_states = tuple(hidden_state.transpose(0, 1) for hidden_state in encoder_states)

kills the backprop for some reason. I’m not sure why, but it solved my problem. It would be nice not to have to monkey-patch this to get the gradients.
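If it helps anyone reproduce this, the failure mode can be shown in plain PyTorch: transpose returns a new tensor, and if that transposed copy is what gets returned to the caller while the rest of the forward pass keeps computing from the pre-transpose tensor, the copy sits on a dead-end branch of the autograd graph. Backward never reaches it, so its .grad stays None even with retain_grad. A minimal sketch:

```python
import torch

x = torch.randn(2, 3, requires_grad=True)
hidden = x * 2.0
hidden.retain_grad()

# A transposed copy is handed back to the caller...
hidden_t = hidden.transpose(0, 1)
hidden_t.retain_grad()

# ...but the rest of the forward pass keeps using the original tensor,
# so the copy is never part of the path to the loss.
loss = hidden.sum()
loss.backward()

print(hidden.grad is None)    # False: on the path to the loss
print(hidden_t.grad is None)  # True: backward never visits the copy
```

This matches what I was seeing: the returned encoder_hidden_states were the transposed copies, not the tensors the decoder actually consumed.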

That’s really strange, not sure I understand why that’s the case. Can you open a GitHub issue and link this topic as well as your SO thread?