Why reshape attn_weights when outputting attentions?

In the BART source code, when the BartAttention class outputs attention weights, the weights are reshaped twice in order to "keep its gradient". I wonder why this operation is necessary, since attn_weights has the same shape before and after it.

    if output_attentions:
        # this operation is a bit awkward, but it's required to
        # make sure that attn_weights keeps its gradient.
        # In order to do so, attn_weights have to be reshaped
        # twice and have to be reused in the following
        attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
        attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
    else:
        attn_weights_reshaped = None
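
Just to make the observation concrete, here is a minimal standalone sketch (with made-up sizes for bsz, num_heads, tgt_len, src_len, outside the actual BartAttention class) that reproduces only this reshape round trip and shows that attn_weights ends up with the same shape it had before:

    import torch

    # hypothetical sizes, purely for illustration
    bsz, num_heads, tgt_len, src_len = 2, 4, 5, 6

    # attn_weights as BartAttention holds it: (bsz * num_heads, tgt_len, src_len)
    scores = torch.randn(bsz * num_heads, tgt_len, src_len, requires_grad=True)
    attn_weights = torch.softmax(scores, dim=-1)

    # the two views from the snippet above
    attn_weights_reshaped = attn_weights.view(bsz, num_heads, tgt_len, src_len)
    attn_weights = attn_weights_reshaped.view(bsz * num_heads, tgt_len, src_len)

    # same shape (and same values) as before the two views
    print(attn_weights.shape)  # torch.Size([8, 5, 6])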