Using attention matrices to explain a classification model?


I am currently fine-tuning a multi-label sequence classification model, where each sequence is assigned one or more classes from a pre-defined set. (Since a sequence can belong to several classes at the same time, this is slightly different from a standard multi-class problem.)
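To make the setup concrete, here is a toy sketch of what I mean by multi-label prediction (pure NumPy, with made-up logits for three hypothetical classes): each class gets an independent sigmoid and its own threshold, rather than a single softmax over all classes.

```python
import numpy as np

# Hypothetical per-class logits from the classification head (3 classes)
logits = np.array([2.1, -0.5, 0.3])

# Independent sigmoid per class: probabilities need not sum to 1
probs = 1.0 / (1.0 + np.exp(-logits))

# Thresholding each class separately lets a sequence carry several labels
predicted = probs > 0.5
print(predicted.tolist())  # → [True, False, True]
```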

I was wondering whether the attention matrices of BERT, DistilBERT, or similar encoder-only transformer models can be used to explain the model's decisions to lay people. In particular, I am inspired by the use of these matrices to explain seq-to-seq models such as machine translation (something like this), where the attention matrix shows which tokens of the source sentence are mainly attended to when generating each token of the target sentence.

I know that an encoder-only architecture like a sequence classification model only has self-attention matrices, not cross-attention ones, but I was nonetheless wondering whether they can be used to understand which tokens are most significant for the model when it assigns a sequence to a certain class or classes. Any pointers in the right direction are much appreciated!
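For concreteness, here is a toy sketch of the kind of reading I have in mind, in pure NumPy with random projections standing in for a trained head. Since classification models typically pool the sequence through the first ([CLS]) token, my idea is to look at the [CLS] row of a self-attention matrix as a (rough) per-token importance score. Everything here (the dimensions, the single head, the random weights) is made up for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, d = 5, 8  # 5 tokens, head dimension 8; token 0 plays the role of [CLS]

# Random query/key vectors stand in for a trained attention head
Q = rng.standard_normal((seq_len, d))
K = rng.standard_normal((seq_len, d))

# Scaled dot-product attention with a numerically stable softmax per row
scores = Q @ K.T / np.sqrt(d)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
attn = weights / weights.sum(axis=-1, keepdims=True)  # each row sums to 1

# The [CLS] row: how much the pooled representation attends to each token
cls_attention = attn[0]
print(cls_attention)
```

With a real model (e.g. a HuggingFace checkpoint loaded with `output_attentions=True`), the open questions for me are which layer and head to pick, and whether averaging over heads gives anything faithful enough to show to a lay audience.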