Hi, I am trying to reproduce the translation-attention visualization from the TensorFlow NMT tutorial, but using the MarianMT model.
Basically, the goal is to tell which source word corresponds to each generated target word.
I am not sure whether I have used the correct field of the transformers model output.
Here is the core code
from transformers import MarianMTModel, MarianTokenizer
import numpy as np
class MarianZH():
    """English -> Chinese translation with MarianMT, exposing per-sentence
    cross-attention matrices that align each generated token with the
    source tokens (the "translation attention" of the TF NMT tutorial).
    """

    def __init__(self):
        # Downloads/loads the pretrained en->zh Marian checkpoint.
        model_name = 'Helsinki-NLP/opus-mt-en-zh'
        self.tokenizer = MarianTokenizer.from_pretrained(model_name)
        print(self.tokenizer.supported_language_codes)
        self.model = MarianMTModel.from_pretrained(model_name)

    def input_format(self, en_text):
        """Prefix input text with the target-language token.

        Accepts a single string or a list of strings; always returns a
        list of prefixed strings.

        Raises:
            TypeError: if en_text is neither str nor list.
        """
        if isinstance(en_text, list):
            return [">>cmn_Hans<< " + t for t in en_text]
        if isinstance(en_text, str):
            return [
                '>>cmn_Hans<< ' + en_text,
            ]
        raise TypeError("Unsupported type of {}".format(en_text))

    def get_attention_weight(self, en_text):
        """Yield one (target_len, source_len) cross-attention matrix per input.

        Generates the translation FIRST, then re-runs the forward pass
        feeding the generated tokens as decoder_input_ids. Without this,
        transformers 3.5.1 silently builds decoder inputs by shifting the
        *source* tokens, so cross_attentions would not correspond to the
        actual translation (the likely cause of attention piling up on the
        first/last columns).
        """
        src_text = self.input_format(en_text)
        batch = self.tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")
        # Decode first so we know the actual output token sequence.
        translated = self.model.generate(**batch)
        tensor_output = self.model(
            input_ids=batch['input_ids'],
            attention_mask=batch.get('attention_mask'),
            decoder_input_ids=translated,
            return_dict=True,
            output_attentions=True,
        )
        # Last decoder layer; shape is (batch, heads, target_len, source_len).
        attention_weights = tensor_output.cross_attentions[-1].detach()
        batch_size, num_heads, target_len, source_len = attention_weights.shape
        for i in range(batch_size):
            # Aggregate over heads; rows sum over source positions.
            cross_weight = np.sum(attention_weights[i].numpy(), axis=0)
            yield cross_weight
if __name__ == '__main__':
    # NOTE: input_format() adds the ">>cmn_Hans<<" language token itself,
    # so pass plain English here (the original double-prefixed the text).
    src_text = [
        'Thai food is delicious.',
    ]
    mdl = MarianZH()
    # get_attention_weight is a generator: iterate it so the forward pass
    # actually runs (the original never consumed it).
    for attention_weight in mdl.get_attention_weight(src_text):
        print(attention_weight.shape)
By the way, I am using transformers==3.5.1.
Is this cross_weight the attention matrix corresponding to the translation attention? The output always seems to concentrate on the first or last columns, which makes me suspect something is wrong.