How to get translation attention with MarianMT

Hi, I am trying to get translation attention from the MarianMT model, like in the TensorFlow NMT-with-attention tutorial.


Basically, the goal is to tell which source word corresponds to each generated word.
I am not sure whether I am using the correct field of the transformers output.
Here is the core code:
from transformers import MarianMTModel, MarianTokenizer
import numpy as np

class MarianZH():
    def __init__(self):
        model_name = 'Helsinki-NLP/opus-mt-en-zh'
        self.tokenizer = MarianTokenizer.from_pretrained(model_name)
        print(self.tokenizer.supported_language_codes)
        self.model = MarianMTModel.from_pretrained(model_name)

    def input_format(self, en_text):
        # Prepend the target-language token expected by the multilingual en-zh model.
        if isinstance(en_text, list):
            # batch of sentences
            src_text = [">>cmn_Hans<< " + i for i in en_text]
        elif isinstance(en_text, str):
            src_text = [">>cmn_Hans<< " + en_text]
        else:
            raise TypeError("Unsupported type of {}".format(en_text))
        return src_text


    def get_attention_weight(self, en_text):
        src_text = self.input_format(en_text)
        batch = self.tokenizer.prepare_seq2seq_batch(src_text)
        # Forward pass over the source only; no decoder_input_ids are passed explicitly.
        tensor_output = self.model(batch['input_ids'], return_dict=True, output_attentions=True)
        # Cross attention of the last decoder layer.
        attention_weights = tensor_output.cross_attentions[-1].detach()
        batch_size, attention_heads, input_seq_length, output_seq_length = attention_weights.shape
        translated = self.model.generate(**batch)
        for i in range(batch_size):
            attention_weight_i = attention_weights[i, :, :, :].reshape(attention_heads, input_seq_length, output_seq_length)
            # Sum over the attention heads to get one alignment matrix per sentence.
            cross_weight = np.sum(attention_weight_i.numpy(), axis=0)
            yield cross_weight

if __name__ == '__main__':
    # Plain English input; input_format() adds the '>>cmn_Hans<<' language token.
    src_text = [
        'Thai food is delicious.',
        ]
    mdl = MarianZH()
    for attention_weight in mdl.get_attention_weight(src_text):
        print(attention_weight.shape)

Btw, I am using transformers==3.5.1.

Is this cross_weight the attention matrix that corresponds to the translation attention? The output always seems to focus on the first or the last columns.

I am not sure whether this attention weight matrix is even accessible in the MarianMT model, since its structure is different from the model in the TF tutorial. Could anyone tell me whether this is possible?
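
One idea I had (just a rough sketch, not sure it is the intended way, and it assumes cross_attentions is returned the same way as in my code above): call generate() first, then run a second forward pass with the generated ids passed as decoder_input_ids, so that the rows of cross_attentions line up with the generated tokens instead of with decoder inputs derived from the source. The names generated, outputs, alignment, etc. are just placeholders I made up; averaging over heads instead of summing only changes the scale.

import torch
from transformers import MarianMTModel, MarianTokenizer

model_name = 'Helsinki-NLP/opus-mt-en-zh'
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)

src_text = ['>>cmn_Hans<< Thai food is delicious.']
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors='pt')

# 1) Generate the translation as usual.
generated = model.generate(**batch)

# 2) Run a second forward pass, feeding the generated ids back in as
#    decoder_input_ids, so every decoder position corresponds to a token
#    of the actual translation.
with torch.no_grad():
    outputs = model(
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        decoder_input_ids=generated,
        output_attentions=True,
        return_dict=True,
    )

# One cross-attention tensor per decoder layer,
# each of shape (batch, heads, target_len, source_len).
cross = outputs.cross_attentions[-1][0]   # last layer, first sentence
alignment = cross.mean(dim=0)             # average over heads -> (target_len, source_len)

src_tokens = tokenizer.convert_ids_to_tokens(batch['input_ids'][0].tolist())
tgt_tokens = tokenizer.convert_ids_to_tokens(generated[0].tolist())
# Position 0 of the generated sequence is the decoder start token (<pad> for Marian).
for tgt_tok, row in zip(tgt_tokens, alignment):
    print(tgt_tok, '->', src_tokens[row.argmax().item()])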

AFAIK Marian is implemented as a subclass of BART for generation, which does output cross-attention weights.

cc @patrickvonplaten
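
In case it helps to check visually whether the matrix really collapses onto the first or last columns, here is a small plotting helper in the spirit of the attention heatmaps from the TF tutorial. It assumes the alignment, src_tokens and tgt_tokens variables from the sketch above; matplotlib is not part of my original code.

import matplotlib.pyplot as plt

def plot_attention(alignment, src_tokens, tgt_tokens):
    """Draw a (target_len, source_len) cross-attention matrix as a heatmap."""
    fig, ax = plt.subplots()
    ax.matshow(alignment, cmap='viridis')
    ax.set_xticks(range(len(src_tokens)))
    ax.set_xticklabels(src_tokens, rotation=90)
    ax.set_yticks(range(len(tgt_tokens)))
    ax.set_yticklabels(tgt_tokens)
    plt.show()

# e.g. with the tensors from the sketch above:
# plot_attention(alignment.numpy(), src_tokens, tgt_tokens)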