Beam search (FlaxT5) generates PAD tokens mid-generation

Hi,

When I use beam search (num_beams > 1) with FlaxT5ForConditionalGeneration, the model randomly generates PAD tokens in the middle of generation, and the output is usually incomplete.

Is this a known issue?

I have seen that FlaxMarianMTModel overrides the _adapt_logits_for_beam_search method to fix a similar issue. Is something similar required for FlaxT5ForConditionalGeneration as well? If someone can point me in the right direction, I can look at implementing it and opening a PR.
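For context, my understanding is that the Marian override simply masks the pad token's logits to -inf before beam search scores candidates, so the pad token can never win a beam step. Here is a minimal numpy sketch of that idea (the function name and PAD_TOKEN_ID = 0 for T5 are my assumptions, not the library's exact implementation):

```python
import numpy as np

PAD_TOKEN_ID = 0  # T5 uses 0 as its pad token id; adjust for other models

def adapt_logits_for_beam_search(logits: np.ndarray) -> np.ndarray:
    """Sketch of the Marian-style fix: set the pad token's logit to -inf
    so beam search can never select it mid-generation."""
    logits = logits.copy()
    logits[..., PAD_TOKEN_ID] = -np.inf
    return logits

# Toy logits: batch of 1, 2 beams, vocab of 5; pad (id 0) scores highest
logits = np.array([[[3.0, 1.0, 2.5, 0.5, 0.1],
                    [4.0, 2.0, 1.0, 3.5, 0.2]]])
masked = adapt_logits_for_beam_search(logits)
print(masked.argmax(axis=-1))  # pad is no longer selectable: [[2 3]]
```

If the same masking works for T5, the override would presumably go on FlaxT5ForConditionalGeneration in the same way it does on FlaxMarianMTModel.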

Thank you!

I have also encountered this issue recently. I used the official t5_summarization_flax.py script, but ran it on an internal dataset. I get good results with greedy search, but much worse results with num_beams > 1. On some examples it generates the whole sequence, but on others it starts generating pad tokens mid-sequence without ever outputting an EOS token.

I tried monkeypatching the MarianMT _adapt_logits_for_beam_search onto the Flax T5 model, since the method's docstring suggested it was meant to solve exactly this type of issue, but it didn't change T5's outputs at all.

Any thoughts on why this might be happening, or how we could resolve it for Flax T5 beam search @patrickvonplaten / @patil-suraj?