Customizing GenerationMixin to output attentions

Hi all,
I’m using a Pegasus model (or really BartForConditionalGeneration since almost everything is inherited) and I’m interested in the attention outputs of various encoder and decoder blocks throughout the model.

Following the documentation, simply tokenizing an input context and running
model(**input_tokens, output_attentions=True)
lets me dissect the attentions over each token in the input sequence, one tensor per layer with dimensions (batch_size, num_heads, seq_length, seq_length). This is good.
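
For reference, here is a minimal, self-contained sketch of that call. facebook/bart-base is just a stand-in checkpoint for illustration; any Pegasus/Bart-style seq2seq checkpoint should behave the same way:

    from transformers import BartTokenizer, BartForConditionalGeneration

    # Stand-in checkpoint for illustration; the shapes are what matter here.
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

    input_tokens = tokenizer("Some input context.", return_tensors="pt")
    outputs = model(**input_tokens, output_attentions=True, return_dict=True)

    # One tensor per encoder layer, each (batch_size, num_heads, seq_length, seq_length);
    # outputs.decoder_attentions is analogous over the (auto-shifted) decoder inputs.
    for attn in outputs.encoder_attentions:
        print(attn.shape)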

Now I want to see the attentions that lead to the prediction of each token returned from model.generate(). Since I am on master and built from source, I edited the method _generate_beam_search in the GenerationMixin class in transformers/generation_utils.py to also pass the arg output_attentions=True. Here is the edited snippet:

    decoder_attentions = []  # added this line

    while cur_len < max_length:
        model_inputs = self.prepare_inputs_for_generation(
            input_ids, past=past, attention_mask=attention_mask, use_cache=use_cache, **model_kwargs
        )
        outputs = self(**model_inputs, return_dict=True, output_attentions=True)  # edited this line
        decoder_attentions.append(outputs['decoder_attentions'])  # added this line
        next_token_logits = outputs.logits[:, -1, :]  # (batch_size * num_beams, vocab_size)

As you can see, I try to grab the decoder attentions on every iteration, append them to a list, and later pass that list up and return it from the main generate() method.

The problem is that these attentions are no longer of shape (batch_size, num_heads, seq_length, seq_length). Specifically, I would expect the sequence lengths to be at least the input-context sequence length, but they are much shorter and do not even match the length of the final predicted sequence.

I feel like I have some misunderstanding about how things are working. I know encoder_outputs and some hidden states are precomputed, but I don’t know whether they are affecting this.
Can anyone help me understand what is going on here? Is there a way to see the influence of specific input context tokens in decoder block attention heads on predicted tokens?

I think this is just returning the attentions from the decoder, since the encoder hidden states are only computed once in the generate function. So the attentions returned at every iteration are the attention over the tokens generated so far.

pinging @sshleifer

@azhx what shape are the attentions?

What’s going on is called “incremental decoding/caching”. Bart is reusing all the information it can from the previous decoder step and trying to only do “math” on the latest decoder token.

This all happens before attn_weights are computed (all done by https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_bart.py#L698), so I would expect your decoder self-attention weights to be of shape (batch_size, num_heads, cur_len, cur_len) at each step.

Note that cross attention is never returned. I have a very experimental working branch that concatenates the cross attn outputs with the self attention outputs and returns them all as one big tensor: https://github.com/huggingface/transformers/pull/6967

Thanks for helping!

So from what you’re saying, attn_weights are not necessary for this decoding step? I thought they were used to compute the output of the attention layers in decoder blocks, which is part of the math required in the decoding step, and so would need to be computed before the tokens are fed through?

I also thought the dimensions would be (batch_size, num_heads, cur_len, cur_len), but in one of the examples I tried, the attention weights’ shape seemed to be
(batch_size, num_heads, 1, cur_len)

That third dimension is always one. For instance, the final index (index 37) of the decoder_attentions list in my example gave dimensions (8, 16, 1, 37). I suppose the 8 (batch_size) comes from the number of beams, 16 is clearly num_heads, and the final dimension corresponds to the index in the list (so cur_len), but I don’t understand the 1. The other trouble I mentioned was that 37 was not the length of the final best sequence prediction returned by generate() - that was 23.
@sshleifer

37 should be the length of the longest hypothesis considered by beam search, not necessarily the best.

The shape of the weights should be (batch_size, num_heads, query_len, key_len).
Since the decoder processes 1 new query token at a time, your shapes make sense.
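
Here is a rough sketch of that caching behaviour on a recent transformers version; facebook/bart-base and the past_key_values argument name are my assumptions (argument names may differ on older master builds):

    import torch
    from transformers import BartTokenizer, BartForConditionalGeneration

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

    inputs = tokenizer("A short example sentence.", return_tensors="pt")
    decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

    # Step 1: no cache yet, so query_len == key_len == 1.
    out = model(**inputs, decoder_input_ids=decoder_input_ids, use_cache=True,
                output_attentions=True, return_dict=True)
    print(out.decoder_attentions[0].shape)  # (1, num_heads, 1, 1)

    # Step 2: feed only the newest token plus the cache; the query dim stays 1
    # while the key dim grows with cur_len.
    next_token = out.logits[:, -1, :].argmax(-1, keepdim=True)
    out = model(**inputs, decoder_input_ids=next_token,
                past_key_values=out.past_key_values, use_cache=True,
                output_attentions=True, return_dict=True)
    print(out.decoder_attentions[0].shape)  # (1, num_heads, 1, 2)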

What I’m not sure about is whether after your best hypothesis is “finalized” in beam search, it will keep being passed to the model. I suspect it will not.

You might be able to recover the attention weights of a finalized hypothesis more easily by calling

best_generation = model.generate(src_tokens)  # src_tokens: a tensor of input ids
outputs = model(src_tokens, labels=best_generation, output_attentions=True, return_dict=True)
outputs.decoder_attentions  # one tensor per decoder layer over the full best hypothesis
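
A usage note on the snippet above (this assumes return_dict=True as written; whether cross attention is exposed in the output depends on your installed version, so treat that part as an assumption):

    # One tensor per decoder layer, each (batch_size, num_heads, tgt_len, tgt_len)
    # over the teacher-forced best hypothesis.
    for layer_idx, attn in enumerate(outputs.decoder_attentions):
        print(layer_idx, attn.shape)

    # If your version also returns cross attention, each of those tensors is
    # (batch_size, num_heads, tgt_len, src_len): generated tokens attending over
    # the input context, which is closest to what the original question asks for.
    if getattr(outputs, "cross_attentions", None) is not None:
        print(outputs.cross_attentions[0].shape)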