Big `generate()` refactor

Soon we will merge a big refactor of the generate() method:
All transformer models that have a language model head rely on the generate() method, e.g.

  • Bart, T5, Marian, ProphetNet for summarization, translation, …
  • GPT2, XLNet for open-ended text generation

The generate() method has become hard to read and quite difficult to maintain. The goals of the refactor are therefore as follows:

  1. Remove the _no_beam_search_generation and _beam_search_generation methods and replace them with four generation methods, one for each type of generation: greedy_search (num_beams=1, do_sample=False), sample (num_beams=1, do_sample=True), beam_search (num_beams>1, do_sample=False), and beam_sample (num_beams>1, do_sample=True).
    The following philosophy has been adopted here: the original generate() method stays 100% backwards compatible (ping me on GitHub under @patrickvonplaten if your specific use case has nevertheless been broken by this PR). In addition, each of the four specific generation methods can be used directly. The specific methods are as “bare-bones” as possible, meaning that they don’t contain any magic pre-processing of input tensors. For example, generate() automatically runs the input_ids through the encoder if the model is an encoder-decoder model, whereas for each specific generation method the encoder outputs have to be passed as encoder_outputs in model_kwargs. Likewise, generate() automatically creates input_ids if none are given and expands them for num_beams, whereas for each specific generation method this has to be done manually. The reason behind this design is that it gives the user much more flexibility for specific use cases of generate() and improves maintainability and readability. The user should not be limited in any way when using the specific generation methods directly. It should therefore pave the way for backprop through generate(), make beam generation much easier, and make it easier to write “higher-level” generate functions, as is done for RAG, etc. For more information on this design, please read the docs and look at the examples for greedy_search, sample, beam_search and beam_sample.

  2. All of the generate() parameters that can be used to tweak the logits distribution for better generation results, e.g. no_repeat_ngram_size, min_length, …, are now defined as separate classes that are added to a LogitsProcessorList. This has the following advantages: a) better readability, b) much easier testing of these functions, and c) easier addition of new logits distribution warpers. A huge thanks goes to who had the original idea for this design in this PR:

  3. Move all beam-search-related code into its own file and speed up beam search. Beam search has become increasingly important thanks to many new and improved seq2seq models. This PR moves the very hard-to-understand beam search code into its own file, which makes the beam_search generate function easier to follow. Additionally, all Python list operations have been replaced by torch tensor operations, which led to a 5-10% speed-up for beam_search generation. This change improves speed, readability and maintainability. New beam search algorithms, such as , should now be easier to add to the library.

  4. Tests have been refactored and new, more aggressive tests have been added. The tests can also be very helpful for understanding exactly how each of the methods works. Check out, test_generation_beam_search and test_generation_logits_process.

  5. More docstrings, especially on beam search and logits processors, to make generate() more accessible and understandable to the user.
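The flag-to-method mapping from item 1 and the processor-list idea from item 2 can be sketched in plain Python. This is a hypothetical, simplified illustration: pick_generation_method, MinLengthProcessor, BanImmediateRepeatProcessor, ProcessorList and the toy scoring function are made up here and are not the transformers classes. It only shows how the two flags select one of the four methods and how a bare-bones greedy loop, with no hidden pre-processing, applies a list of score processors in order.

```python
import math
from typing import Callable, List

# Plain-Python sketch; every name below is hypothetical and NOT a transformers
# class. Token 0 plays the role of the EOS token in the usage example.
Scores = List[float]


def pick_generation_method(num_beams: int, do_sample: bool) -> str:
    """Map the two generate() flags onto the four specific methods."""
    if num_beams < 1:
        raise ValueError("num_beams must be >= 1")
    if num_beams == 1:
        return "sample" if do_sample else "greedy_search"
    return "beam_sample" if do_sample else "beam_search"


class MinLengthProcessor:
    """Hypothetical processor: forbid EOS until the sequence has min_length tokens."""

    def __init__(self, min_length: int, eos_token_id: int) -> None:
        self.min_length = min_length
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: List[int], scores: Scores) -> Scores:
        scores = list(scores)  # copy so the caller's list is not mutated
        if len(input_ids) < self.min_length:
            scores[self.eos_token_id] = -math.inf
        return scores


class BanImmediateRepeatProcessor:
    """Hypothetical processor: forbid repeating the previous token."""

    def __call__(self, input_ids: List[int], scores: Scores) -> Scores:
        scores = list(scores)
        if input_ids:
            scores[input_ids[-1]] = -math.inf
        return scores


class ProcessorList(list):
    """Applies each processor in order, in the spirit of a LogitsProcessorList."""

    def __call__(self, input_ids: List[int], scores: Scores) -> Scores:
        for processor in self:
            scores = processor(input_ids, scores)
        return scores


def greedy_search(
    input_ids: List[int],
    next_token_scores: Callable[[List[int]], Scores],
    processors: ProcessorList,
    max_length: int,
    eos_token_id: int,
) -> List[int]:
    """Bare-bones greedy loop: the caller prepares input_ids and processors."""
    input_ids = list(input_ids)
    while len(input_ids) < max_length:
        scores = processors(input_ids, next_token_scores(input_ids))
        next_token = max(range(len(scores)), key=scores.__getitem__)  # argmax
        input_ids.append(next_token)
        if next_token == eos_token_id:
            break
    return input_ids


if __name__ == "__main__":
    def toy_scores(ids: List[int]) -> Scores:
        return [5.0, 1.0, 2.0, 0.5]  # toy "model" that always prefers EOS (token 0)

    processors = ProcessorList([MinLengthProcessor(3, eos_token_id=0), BanImmediateRepeatProcessor()])
    print(pick_generation_method(num_beams=1, do_sample=False))  # greedy_search
    print(greedy_search([2], toy_scores, processors, max_length=10, eos_token_id=0))  # [2, 1, 2, 0]
```

Keeping each processor a small callable is what makes them individually testable: removing BanImmediateRepeatProcessor from the list changes the toy output to [2, 2, 2, 0] without touching the loop.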


Next steps:
  • Do the same refactor for TF
  • Check the possibility of carrying gradients through generate()
  • Add GenerationOutputs, similar to ModelOutputs, that allow returning attention outputs and hidden states

Hi Patrick @patrickvonplaten,

Thanks for your great work!! I have a question.
I have almost finished translating TFDPR and am currently also working on TFRag (I already got TFRagModel working like the PyTorch version, and I am now investigating TFRagTokenForGeneration).

Since the current RagTokenForGeneration code uses both _no_beam_search_generation and _beam_search_generation, I think I should just go ahead with the current version and adapt it to the refactor later.

Or would you suggest that I wait for the PyTorch refactor, or for both the PyTorch and TF refactors?

It’s amazing to hear that you are working on TFRagModel! Do you already have a PR for this?

I’d definitely say that you should go ahead with _no_beam_search_generation and _beam_search_generation in TF (it’ll probably take at least 2-3 weeks until I start the TF refactor and more than a month until I merge it). Feel free to tag me on the TFRagModel PR :slight_smile:


Hi Patrick, thank you for your suggestion.

As RAG uses DPR’s pretrained weights as a submodule, I made a TFDPR PR first. TFRag is now about 25% done… I hope I can open a TFRag PR before EMNLP, and I will definitely ask for your suggestions on it :smiley:

In case you are interested, my current draft of TFRagModel is here

Hello – on a hopefully related note, I saw in a recent Hugging Face email that you had implemented Diverse Beam Search. How is it applied, and how can I use it?

I have been using beam search with sampling in a summarization pipeline (via bart-large-cnn), and I was wondering how I can add more diversity to the generated beams. Thanks!

Hi, you can use the new diverse beam search easily by passing two new parameters to generate():
num_beam_groups and diversity_penalty, where num_beams must be divisible by num_beam_groups (so num_beam_groups <= num_beams).


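    from transformers import BartForConditionalGeneration, BartTokenizer
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
    input_ids = tokenizer("Replace this with the text to summarize.", return_tensors="pt").input_ids
    outputs = model.generate(input_ids, min_length=1, max_length=200,
                             no_repeat_ngram_size=3, num_beams=5, early_stopping=True,
                             do_sample=False, num_return_sequences=5,
                             num_beam_groups=5, diversity_penalty=2.0)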

It only works with do_sample=False, however.
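Roughly, diverse beam search splits the num_beams beams into num_beam_groups groups that are decoded one after another at each step; a group’s token scores are lowered by diversity_penalty once for every time an earlier group already chose that token at the same step (a Hamming-style diversity term). The sketch below is a simplified, hypothetical plain-Python illustration with one beam per group and made-up helper names (beams_per_group, hamming_penalty, diverse_step); it is not the transformers implementation, and the divisibility check reflects my understanding of what generate() expects.

```python
from typing import List

# Hypothetical plain-Python sketch of the diversity term, NOT the transformers code.


def beams_per_group(num_beams: int, num_beam_groups: int, do_sample: bool) -> int:
    """Validate the diverse beam search arguments and return the group size."""
    if do_sample:
        raise ValueError("diverse beam search only works with do_sample=False")
    if num_beam_groups < 1 or num_beams % num_beam_groups != 0:
        raise ValueError("num_beams must be divisible by num_beam_groups")
    return num_beams // num_beam_groups


def hamming_penalty(scores: List[float], previous_group_tokens: List[int],
                    diversity_penalty: float) -> List[float]:
    """Subtract diversity_penalty once per time an earlier group chose a token at this step."""
    penalized = list(scores)
    for token in previous_group_tokens:
        penalized[token] -= diversity_penalty
    return penalized


def diverse_step(group_scores: List[List[float]], diversity_penalty: float) -> List[int]:
    """One decoding step with a single beam per group: groups choose greedily, in order."""
    chosen: List[int] = []
    for scores in group_scores:
        penalized = hamming_penalty(scores, chosen, diversity_penalty)
        chosen.append(max(range(len(penalized)), key=penalized.__getitem__))
    return chosen


if __name__ == "__main__":
    same_scores = [[1.0, 0.8, 0.5]] * 3  # three groups with identical model scores
    print(beams_per_group(num_beams=6, num_beam_groups=3, do_sample=False))  # 2
    print(diverse_step(same_scores, diversity_penalty=0.0))  # [0, 0, 0]
    print(diverse_step(same_scores, diversity_penalty=2.0))  # [0, 1, 2]
```

With no penalty every group picks the same top token; a large enough diversity_penalty pushes later groups onto different tokens, which is where the extra diversity between the returned sequences comes from.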

I just learned about diverse beam search and have to try it with my abstractive text summarizer, which unfortunately produces mostly extractive summaries. I hope diverse beam search helps make the summaries a little more abstractive.