Hi everyone! I am trying to build an attention heat map for the Google T5-small model (for a translation task), but I find the structure of cross_attentions quite misleading:
cross_attentions = outputs.cross_attentions
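For context, this is roughly how I obtain the output (a minimal sketch; the exact prompt and generation arguments are just an example and probably not the issue):

```python
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

inputs = tokenizer(
    "translate English to German: The house is wonderful.",
    return_tensors="pt",
)

# ask generate() to return the attention weights alongside the sequences
outputs = model.generate(
    **inputs,
    output_attentions=True,
    return_dict_in_generate=True,
)
cross_attentions = outputs.cross_attentions
```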
Moreover, it doesn’t seem to match the official documentation either:
cross_attentions (tuple(torch.FloatTensor), optional, returned when output_attentions=True is passed or when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, sequence_length, sequence_length).
Attentions weights of the decoder’s cross-attention layer, after the attention softmax, used to compute the weighted average in the cross-attention heads.
After my investigation, the actual output of cross_attentions is a tuple of length num_output_tokens - 1. Each cross_attentions[i] consists of [ num_heads ] (6 in the case of the small model) elements, where each element is a tensor of size (batch_size, 8, 1, num_of_input_tokens), and I’m having a hard time figuring out what the 8 actually is.
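Printing the shapes is what led me to the description above (again just a quick sketch, assuming cross_attentions comes from the generate() call in the first snippet):

```python
# one entry per decoding step (num_output_tokens - 1 in my runs)
print(len(cross_attentions))
# 6 entries per step for t5-small
print(len(cross_attentions[0]))
# torch.Size([1, 8, 1, num_of_input_tokens]) -- this is the 8 I don't understand
print(cross_attentions[0][0].shape)
```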
I would highly appreciate any explanation of how I should deal with this nested output to build heat maps similar to these:
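For completeness, this is roughly what I have tried so far: taking the last element of each per-step tuple, stacking the per-step tensors, and averaging over the 8-sized dimension. I am not confident the indexing is correct, which is why I’m asking.

```python
import matplotlib.pyplot as plt

layer = -1  # last element of each per-step tuple

# for each decoding step: tensor (1, 8, 1, num_of_input_tokens) -> (8, num_of_input_tokens)
step_maps = [step[layer][0, :, 0, :] for step in cross_attentions]
attn = torch.stack(step_maps)   # (num_steps, 8, num_of_input_tokens)
attn = attn.mean(dim=1)         # average over the 8-sized dimension

input_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
# drop the decoder start token so the rows line up with the decoding steps
output_tokens = tokenizer.convert_ids_to_tokens(outputs.sequences[0].tolist())[1:]

plt.imshow(attn.detach().numpy(), aspect="auto")
plt.xticks(range(len(input_tokens)), input_tokens, rotation=90)
plt.yticks(range(len(output_tokens)), output_tokens)
plt.xlabel("input tokens")
plt.ylabel("generated tokens")
plt.colorbar()
plt.show()
```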