Hi all,
I’m using a Pegasus model (or really BartForConditionalGeneration, since almost everything is inherited), and I’m interested in the attention outputs of the various encoder and decoder blocks throughout the model.
Following the documentation, simply tokenizing an input context and running model(**input_tokens, output_attentions=True) allows me to dissect the attentions over each token in the input sequence at every layer (the dimensions being (batch_size, num_heads, seq_length, seq_length) for each layer). This is good.
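For concreteness, here is roughly what that first part looks like on my end (the checkpoint name and input text below are just placeholders, not my exact setup):

from transformers import PegasusTokenizer, PegasusForConditionalGeneration

tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-xsum")
model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")

input_tokens = tokenizer("some long input context ...", return_tensors="pt")
outputs = model(**input_tokens, output_attentions=True, return_dict=True)

# one tensor per layer, each of shape (batch_size, num_heads, seq_length, seq_length)
for layer_attention in outputs.encoder_attentions:
    print(layer_attention.shape)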
Now I want to see the attentions that lead to the prediction of each token returned from model.generate(). Since I am on master and building from source, I edit the method _generate_beam_search in the GenerationMixin class in transformers/generation_utils.py to also pass the arg output_attentions=True. Here is the edited snippet:
decoder_attentions = []  # added this line
while cur_len < max_length:
    model_inputs = self.prepare_inputs_for_generation(
        input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
    )
    outputs = self(**model_inputs, return_dict=True, output_attentions=True)  # edited this line
    decoder_attentions.append(outputs["decoder_attentions"])  # added this line
    next_token_logits = outputs.logits[:, -1, :]  # (batch_size * num_beams, vocab_size)
As you can see, I grab the decoder attentions on every iteration, append them to a list, and later pass that list up so it is returned from the main generate() method.
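Downstream, I consume it roughly like this (reusing model and input_tokens from the snippet above; the tuple return value is my own local change to generate(), not something the library does):

summary_ids, decoder_attentions = model.generate(
    input_tokens["input_ids"],
    num_beams=4,
    max_length=60,
)

print(len(decoder_attentions))     # one entry per decoding step
print(len(decoder_attentions[0]))  # one attention tensor per decoder layer at each step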
The problem is that these attentions no longer have the shape (batch_size, num_heads, seq_length, seq_length). Specifically, I would expect the sequence-length dimensions to be at least as long as the input context, but they are much shorter, and they do not even match the length of the final predicted sequence.
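To make the mismatch concrete, this is the kind of check I run (again relying on the variables from my patched setup above, so the exact numbers are only illustrative):

input_len = input_tokens["input_ids"].shape[1]
print("input context length:", input_len)
print("generated sequence length:", summary_ids.shape[1])

for step, step_attentions in enumerate(decoder_attentions):
    # step_attentions[0] is the layer-0 decoder attention tensor at this step;
    # I expected its sequence-length dimensions to be at least input_len, but they are much shorter
    print("step", step, "layer-0 attention shape:", tuple(step_attentions[0].shape))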
I feel like I have some misunderstanding about how things are working here. I know encoder_outputs and some hidden states are precomputed, but I don’t know whether that is what is affecting this.
Can anyone help me understand what is going on here? Is there a way to see, via the attention heads in the decoder blocks, how specific input context tokens influence each predicted token?