T5 cross-attention - inconsistent results

Environment info

Python version: 3.7.10
PyTorch version (GPU?): '1.7.1+cu110' (True)
Transformers version: '4.5.0.dev0'

Details

I am trying to use the cross-attention from the T5 model for paraphrasing. The idea is to map each token of the generated sequence back to the input tokens it attends to most, but the first results I got are very strange.
I generated an example with the following code:

from transformers import T5ForConditionalGeneration, T5Tokenizer
import torch

pretrained_model = "ramsrigouthamg/t5_paraphraser"

model = T5ForConditionalGeneration.from_pretrained(pretrained_model,
                                                   output_attentions=True,
                                                   output_scores=True)
tokenizer = T5Tokenizer.from_pretrained(pretrained_model)

translated_sentence = "I like drinking Fanta and Cola."

text = "paraphrase: " + translated_sentence + " </s>"

encoding = tokenizer.encode_plus(text,
                                 pad_to_max_length=True,
                                 return_tensors="pt")
input_ids, attention_masks = encoding["input_ids"], encoding["attention_mask"]

Then I looked at the cross-attention for each generated token by selecting the last decoder layer and the first head.


beam_outputs = model.generate(
    input_ids=input_ids,
    attention_mask=attention_masks,
    do_sample=True,
    max_length=256,
    top_k=100,
    top_p=0.95,
    num_return_sequences=1,
    output_attentions=True,
    output_scores=True,
    return_dict_in_generate=True
)

sentence_id = 0
print("Input phrase: ", tokenizer.decode(encoding.input_ids[0],
                                         skip_special_tokens=False,
                                         clean_up_tokenization_spaces=False))
print("Predicted phrase: ", tokenizer.decode(beam_outputs.sequences[sentence_id],
                                             skip_special_tokens=True,
                                             clean_up_tokenization_spaces=True))
for out in range(len(beam_outputs.sequences[sentence_id]) - 1):
    print(
        "\nPredicted word: ",
        tokenizer.decode(beam_outputs.sequences[sentence_id][out],
                         skip_special_tokens=True,
                         clean_up_tokenization_spaces=True))
    # Stack the per-layer cross-attentions for this generation step
    att = torch.stack(beam_outputs.cross_attentions[out])
    # Last decoder layer
    att = att[-1]

    # First batch element and first head
    att = att[0, 0, :, :]
    att = torch.squeeze(att)

    # argsort is ascending: smallest attention first, largest last
    idx = torch.argsort(att)
    idx = idx.cpu().numpy()

    print("Input words ordered by attention: ")
    for i in range(min(5, len(idx))):
        token_smallest_attention = tokenizer.decode(encoding.input_ids[0][idx[i]],
                                                    skip_special_tokens=True,
                                                    clean_up_tokenization_spaces=True)
        token_largest_attention = tokenizer.decode(encoding.input_ids[0][idx[-(1 + i)]],
                                                   skip_special_tokens=True,
                                                   clean_up_tokenization_spaces=True)
        print(f"{i+1}: Largest attention: {token_largest_attention} | smallest attention:{token_smallest_attention}")

For each generated token, the attention scores over the input are sorted, and the five input tokens with the highest attention and the five with the lowest attention are printed.

Input phrase:  paraphrase: I like drinking Fanta and Cola.</s>
Predicted phrase:  I like to drink Fanta and Cola.

Predicted word:  <pad>
Input words ordered by attention: 
1: Largest attention: I | smallest attention:Col
2: Largest attention: like | smallest attention:a
3: Largest attention: : | smallest attention:t
4: Largest attention: para | smallest attention:a
5: Largest attention: . | smallest attention:Fan

Predicted word:  I
Input words ordered by attention: 
1: Largest attention: phrase | smallest attention:t
2: Largest attention: </s> | smallest attention:a
3: Largest attention: para | smallest attention:a
4: Largest attention: : | smallest attention:Col
5: Largest attention: like | smallest attention:and

Predicted word:  like
Input words ordered by attention: 
1: Largest attention: Fan | smallest attention:I
2: Largest attention: Col | smallest attention:.
3: Largest attention: phrase | smallest attention:like
4: Largest attention: a | smallest attention:para
5: Largest attention: </s> | smallest attention:a

Expected results

I was expecting an almost one-to-one mapping, since the paraphrase is very close to the input, but that is not the case, even though the model produces good paraphrases. Do you think I made an error in interpreting the cross-attention output?

Thank you for your help!

Hopefully, it is something simple that I am missing.

We found the problem: the shape of the attention matrix was not what we assumed.
Here are the modifications:

# From T5 documentation
# Initial shape: Tuple (one element for each generated token) of tuples (one element for
# each layer of the decoder) of torch.FloatTensor of shape
# (batch_size, num_heads, generated_length, sequence_length).

# combine all cross attention into one tensor
x = [torch.stack(step) for step in beam_outputs.cross_attentions]
x = torch.stack(x)
# Shape: (nb_generated, nb_layer, batch_size, num_heads, generated_length, sequence_length)
print(x.shape)

x = x.transpose(1,0)
# (nb_layer, nb_generated, batch_size, num_heads, generated_length, sequence_length)
print(x.shape)

x = x.transpose(1,3)
# (nb_layer, num_heads, batch_size, nb_generated, generated_length, sequence_length)
print(x.shape)

x = torch.squeeze(x, 4)
# (nb_layer, num_heads, batch_size, nb_generated, sequence_length)
print(x.shape)

x = x.transpose(2, 1)
# (nb_layer, batch_size, num_heads, nb_generated, sequence_length)
print(x.shape)
cross_attentions = x
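
Equivalently, here is a shorter sketch of the same reshaping (it assumes generated_length is always 1 when generating token by token): concatenate the per-step tensors along the generated_length axis instead of stacking and squeezing.

# Alternative reshaping (sketch): generated_length is 1 at every step, so the
# per-step tensors can simply be concatenated along that axis.
num_layers = len(beam_outputs.cross_attentions[0])
per_layer = [
    torch.cat([step[layer] for step in beam_outputs.cross_attentions], dim=2)
    for layer in range(num_layers)
]  # each element: (batch_size, num_heads, nb_generated, sequence_length)
cross_attentions_alt = torch.stack(per_layer)
# (nb_layer, batch_size, num_heads, nb_generated, sequence_length)
print(cross_attentions_alt.shape)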

We can recover the encoder and decoder tokens:

import numpy as np

encoder_text = tokenizer.convert_ids_to_tokens(input_ids[0])
decoder_text = tokenizer.convert_ids_to_tokens(beam_outputs.sequences[0])

encoder_tokens = np.array(encoder_text)
# Drop the last decoder token so the array length matches the number of generation steps
decoder_tokens = np.array(decoder_text[:-1])
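
A quick sanity check (a sketch, reusing the variables defined above) confirms that the token arrays line up with the reshaped tensor:

# Sanity check (sketch): cross_attentions has shape
# (nb_layer, batch_size, num_heads, nb_generated, sequence_length)
assert cross_attentions.shape[-1] == len(encoder_tokens)  # encoder sequence length
assert cross_attentions.shape[3] == len(decoder_tokens)   # one attention map per generation step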

The initial sentence is "In Belgium, summers can be very dry and the heat is burning." and the paraphrase is "In Belgium the summers can be very dry and the heat is burning."
The associated tokens are:

['▁para', 'phrase', ':', '▁In', '▁Belgium', ',', '▁summer', 's', '▁can', '▁be', '▁very', '▁dry', '▁and', '▁the', '▁heat', '▁is', '▁burning', '.', '</s>']
['<pad>', '▁In', '▁Belgium', '▁the', '▁summer', 's', '▁can', '▁be', '▁very', '▁dry', '▁and', '▁the', '▁heat', '▁is', '▁burning', '.', '</s>']

Here are the attention scores sorted by importance:

layer = 0
head = 0  # choose the head to analyze
att = cross_attentions[layer, 0, head]  # (nb_generated, sequence_length)
print(att.shape)

for i in range(att.shape[0]):
    # Sort the input positions by descending attention and keep the top 6
    idx = np.argsort(att[i].cpu().numpy())[::-1][:6]
    print(f"Predicted token: {decoder_tokens[i]}, input related tokens: {encoder_tokens[idx]}")

Here are some results:

Predicted token: ▁In, input related tokens: ['</s>' '▁In' 'phrase' '.' ':' '▁Belgium']
Predicted token: ▁Belgium, input related tokens: ['▁Belgium' '</s>' 'phrase' ':' '▁para' '▁summer']
Predicted token: ▁the, input related tokens: ['</s>' 'phrase' ',' '.' '▁the' ':']
Predicted token: ▁summer, input related tokens: ['▁Belgium' '</s>' 'phrase' '▁In' 's' '▁summer']

We solved this problem with the help of the encoder-decoder version of bertviz.
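
For reference, here is a minimal sketch of how the same visualization can be reproduced with bertviz. It assumes a recent bertviz release whose model_view accepts the encoder-decoder parameters (encoder_attention, decoder_attention, cross_attention, encoder_tokens, decoder_tokens); a single forward pass over the generated sequence is used so that each attention tensor covers the full decoder length.

# Sketch: re-run one forward pass with the generated sequence as decoder input,
# then visualize all attention types with bertviz. The encoder-decoder
# parameters of model_view are assumed to be available in the installed version.
from bertviz import model_view

with torch.no_grad():
    outputs = model(
        input_ids=input_ids,
        attention_mask=attention_masks,
        decoder_input_ids=beam_outputs.sequences[:, :-1],
        output_attentions=True,
        return_dict=True,
    )

model_view(
    encoder_attention=outputs.encoder_attentions,
    decoder_attention=outputs.decoder_attentions,
    cross_attention=outputs.cross_attentions,
    encoder_tokens=tokenizer.convert_ids_to_tokens(input_ids[0]),
    decoder_tokens=tokenizer.convert_ids_to_tokens(beam_outputs.sequences[0, :-1]),
)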