Problem with returning decoder cross attentions through generate function

I'm using a TFBartBase model with generate() and the default beam search settings. I have set the output_attentions, return_dict, and return_dict_in_generate flags to True, but the cross_attentions object in the output is an empty list. After a bit of digging around in beam_search_body_fn(), I found the snippet below, which seems to explain why this is happening.

The snippet below is from the tf_utils.py file:

            # Store scores, attentions and hidden_states when required
            if not use_xla and return_dict_in_generate:
                if output_scores:
                    all_scores.append(
                        logits_warper(
                            flatten_beam_dim(running_sequences),
                            flatten_beam_dim(log_probs_processed),
                            cur_len,
                        )
                    )
                if output_attentions and self.config.is_encoder_decoder:
                    decoder_attentions.append(model_outputs.decoder_attentions)
                elif output_attentions and not self.config.is_encoder_decoder:
                    decoder_attentions.append(model_outputs.attentions)
                    if self.config.is_encoder_decoder:
                        cross_attentions.append(model_outputs.cross_attentions)

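Why is cross_attentions.append(model_outputs.cross_attentions) not part of the first branch, if output_attentions and self.config.is_encoder_decoder? As written, it sits inside the elif branch, which only runs when self.config.is_encoder_decoder is False, so the nested if self.config.is_encoder_decoder check can never pass.
I believe this is why the cross_attentions object comes back empty. Is this a bug, or is there a reason the if/elif statement is structured the way it is?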
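
To make the question concrete, here is a minimal sketch of the same branching using plain Python lists instead of the real model outputs. The function names (collect_as_quoted, collect_restructured) and the fake_outputs object are made up for illustration, and the second function is only my guess at what the intended logic might be, not the library's actual code.

```python
from types import SimpleNamespace


def collect_as_quoted(outputs, is_encoder_decoder, output_attentions=True):
    """Mirror the if/elif structure of the quoted snippet."""
    decoder_attentions, cross_attentions = [], []
    if output_attentions and is_encoder_decoder:
        decoder_attentions.append(outputs.decoder_attentions)
    elif output_attentions and not is_encoder_decoder:
        decoder_attentions.append(outputs.attentions)
        # This branch is only entered when is_encoder_decoder is False,
        # so the check below can never succeed.
        if is_encoder_decoder:
            cross_attentions.append(outputs.cross_attentions)
    return decoder_attentions, cross_attentions


def collect_restructured(outputs, is_encoder_decoder, output_attentions=True):
    """Hypothetical restructuring (an assumption, not the library's code)."""
    decoder_attentions, cross_attentions = [], []
    if output_attentions and is_encoder_decoder:
        decoder_attentions.append(outputs.decoder_attentions)
        cross_attentions.append(outputs.cross_attentions)
    elif output_attentions:
        decoder_attentions.append(outputs.attentions)
    return decoder_attentions, cross_attentions


# Stand-in for model_outputs; the string values are placeholders for tensors.
fake_outputs = SimpleNamespace(
    attentions="self_attn",
    decoder_attentions="dec_attn",
    cross_attentions="cross_attn",
)

print(collect_as_quoted(fake_outputs, is_encoder_decoder=True))
print(collect_restructured(fake_outputs, is_encoder_decoder=True))
```

With is_encoder_decoder=True, the first function leaves the cross-attention list empty while the second one fills it, which matches what I'm seeing with generate().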

I'd really appreciate some insights. Thank you!
