Calculate Impact of Input Tokens on BERT Output Probability

Say I’ve trained a BERT model for classification. I’d like to calculate the proportional impact each input token has on the predicted output.

For example - and this is very general - if I have a model that labels input text as {‘about dogs’ : 0, ‘about cats’ : 1}, the following input sentence:
s = 'this is a sentence about a cat'
should output very close to:
1
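
For concreteness, here’s a minimal sketch of that forward pass with the transformers library, assuming my fine-tuned checkpoint lives at the hypothetical path `./my-cat-dog-classifier`:

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Hypothetical path to the fine-tuned dogs-vs-cats classifier
tokenizer = AutoTokenizer.from_pretrained("./my-cat-dog-classifier")
model = AutoModelForSequenceClassification.from_pretrained("./my-cat-dog-classifier")
model.eval()

s = "this is a sentence about a cat"
inputs = tokenizer(s, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits

probs = torch.softmax(logits, dim=-1)
print(probs)  # expected to be close to [0.0, 1.0], i.e. label 1 ("about cats")
```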

HOWEVER, what I’d like is to calculate each input token’s impact on that final prediction, e.g. (assuming we’re tokenizing on the level of words - which is not how it would be done in practice, I know):
{this : .01, is: .005, a : .02, sentence : .0003, about : [some other low prob], a: [another low prob], cat : 0.999999}

Intuitively, I’d think this means running a forward pass with the input sentence, then looking at the gradients from the backward pass? But I’m not quite sure how you’d actually do that. Thoughts?

Hi @matthew,

I’ve looked into this a bit in the past and your intuitions are nearly spot on! What you have described is something like a saliency map. Here are some references that might be useful:

Taking the simplest case, what is effectively done is a forward pass, then a backward pass all the way back to the input layer. The intuition is that this will give you a sense of which tokens would have the largest impact on the output if you were to change them by a small amount. In other words, if one of the input tokens from a sentence about cats were changed in such a way that the output of the model becomes 0 (i.e. a sentence about dogs), then that particular token would have a high saliency score.
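
To make that concrete, here is a rough sketch of one way to do it with PyTorch and the transformers library. It assumes the hypothetical fine-tuned checkpoint `./my-cat-dog-classifier` from your example, and it uses the L2 norm of the gradient of the predicted-class logit with respect to each token’s embedding as the saliency score (gradient × embedding is another common variant):

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Hypothetical path to the fine-tuned dogs-vs-cats classifier
tokenizer = AutoTokenizer.from_pretrained("./my-cat-dog-classifier")
model = AutoModelForSequenceClassification.from_pretrained("./my-cat-dog-classifier")
model.eval()

sentence = "this is a sentence about a cat"
inputs = tokenizer(sentence, return_tensors="pt")

# Look up the input embeddings ourselves so we can take gradients w.r.t. them
embeddings = model.get_input_embeddings()(inputs["input_ids"]).detach()
embeddings.requires_grad_(True)

# Forward pass, feeding the embeddings directly instead of the input ids
outputs = model(inputs_embeds=embeddings, attention_mask=inputs["attention_mask"])
predicted_class = int(outputs.logits.argmax(dim=-1))

# Backward pass from the predicted-class logit all the way back to the embeddings
outputs.logits[0, predicted_class].backward()

# Per-token saliency: how much a small change to that token's embedding
# would move the predicted-class logit, normalised to proportions
saliency = embeddings.grad.norm(dim=-1).squeeze(0)
saliency = saliency / saliency.sum()

tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
for token, score in zip(tokens, saliency.tolist()):
    print(f"{token}: {score:.4f}")
```

Note that the printed tokens will include special tokens like [CLS] and [SEP] (and possibly subword pieces), so the scores won’t map one-to-one onto the word-level dictionary in your example.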

Some of the papers above were still a bit unclear to me, so here are some supplementary references to hopefully get you on your way:
