Soon we will merge a big refactor of the generate() method: https://github.com/huggingface/transformers/pull/6949
All transformer models that have a language model head rely on the generate() method, e.g.
- Bart, T5, Marian, ProphetNet for summarization, translation, …
- GPT2, XLNet for open-ended text generation
The generate() method has become hard to read and quite difficult to maintain. The goals of the refactor are thus as follows:
- Remove the `_no_beam_search_generation` and `_beam_search_generation` methods and replace them with four different generation methods corresponding to the four different types of generation: `greedy_search` (corresponds to `num_beams=1, do_sample=False`), `sample` (corresponds to `num_beams=1, do_sample=True`), `beam_search` (corresponds to `num_beams>1, do_sample=False`), and `beam_sample` (corresponds to `num_beams>1, do_sample=True`).

  The following philosophy has been adopted here: the original `generate()` method is kept at 100% backwards compatibility (ping me on github under @patrickvonplaten if your specific use case has nevertheless been broken by this PR). In addition, each of the four specific generation methods can be used directly. The specific methods are as "bare-bone" as possible, meaning that they don't contain any magic pre-processing of input tensors. E.g. `generate()` automatically runs the `input_ids` through the encoder if the model is an encoder-decoder model, whereas for each specific generation method the encoder outputs have to be added as `encoder_outputs` to the `model_kwargs`; `input_ids` are automatically created in case they are empty, or automatically expanded for `num_beams`, in `generate()`, whereas this has to be done manually for each specific generation method. The reason behind this design is that it gives the user much more flexibility for specific use cases of `generate()` and improves maintainability and readability. The user should not be limited in any way when directly using the specific generation methods. It should therefore pave the way to allowing backprop through generate, making beam generation much easier, and making it easier to write "higher-level" generate functions, as is done for RAG. For more information on this design, please read the docs and look into the examples of `greedy_search`, `sample`, `beam_search` and `beam_sample`. A minimal sketch of calling `greedy_search` directly is also shown after this list.
- All of the generate parameters that can be used to tweak the logits distribution for better generation results, e.g. `no_repeat_ngram_size`, `min_length`, …, are now defined as separate classes that are added to a `LogitsProcessorList`. This has the following advantages: a) better readability, b) much easier to test these functions, c) easier to add new logits distribution warpers. A short sketch of composing a `LogitsProcessorList` follows this list. A huge thanks goes to https://github.com/turtlesoupy who had the original idea for this design in this PR: https://github.com/huggingface/transformers/pull/5420
- Move all `beam_search`-relevant code into its own generation_beam_search.py file and speed up beam search. Beam search has gained more and more importance thanks to many new and improved seq2seq models. This PR moves the very hard-to-understand beam search code into its own file and makes sure that the `beam_search` generation function is easier to understand this way. Additionally, all Python list operations are now replaced by torch.Tensor operations, which led to a 5-10% speed-up for `beam_search` generation. This change improves speed, readability and maintainability. New beam search algorithms, such as https://arxiv.org/abs/2007.03909, should now be easier to add to the library. A sketch of calling `beam_search` directly with the new `BeamSearchScorer` is shown after this list.
- Tests have been refactored and new, more aggressive tests have been added. The tests can also be very helpful to understand how each of the methods works exactly. Check out `test_generation_utils.py`, `test_generation_beam_search` and `test_generation_logits_process`.
- More docstrings, especially on beam search and logits processors, to make generate more accessible and understandable to the user.
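
To make the "bare-bones" philosophy above concrete, here is a minimal sketch of calling `greedy_search` directly on GPT2 instead of going through `generate()`. The model name and prompt are just examples, and the signatures follow the docs around the time of this PR, so they may differ slightly in other versions:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, LogitsProcessorList, MinLengthLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.config.pad_token_id = model.config.eos_token_id

# generate() would create / pre-process input_ids for us; for greedy_search we do it ourselves
input_ids = tokenizer("Today is a beautiful day, and", return_tensors="pt").input_ids

# generate() builds the logits processors from config values; here we pass them explicitly
logits_processor = LogitsProcessorList(
    [MinLengthLogitsProcessor(15, eos_token_id=model.config.eos_token_id)]
)

outputs = model.greedy_search(input_ids, logits_processor=logits_processor)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```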
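
Similarly, a short sketch of the new logits processor design: each constraint is its own small class, and a `LogitsProcessorList` simply applies them in order. The token ids and tensor shapes below are dummy values for illustration:

```python
import torch
from transformers import LogitsProcessorList, MinLengthLogitsProcessor, NoRepeatNGramLogitsProcessor

# Each processor is a small, individually testable callable mapping
# (input_ids, scores) -> adjusted scores.
processors = LogitsProcessorList(
    [
        MinLengthLogitsProcessor(min_length=10, eos_token_id=50256),  # block EOS before 10 tokens
        NoRepeatNGramLogitsProcessor(ngram_size=2),                   # forbid repeating any 2-gram
    ]
)

input_ids = torch.tensor([[464, 3290, 318]])     # tokens generated so far (dummy values)
scores = torch.randn(1, 50257)                   # next-token logits from the model (dummy values)
adjusted_scores = processors(input_ids, scores)  # processors are applied in order
```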
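
And finally a sketch of calling `beam_search` directly with the new `BeamSearchScorer`, doing by hand the pre-processing that `generate()` normally takes care of (running the encoder, expanding the inputs to `num_beams`, passing `encoder_outputs` via `model_kwargs`). Again, the model name is only an example and the argument names (e.g. `max_length` on the scorer) follow the docs around the time of this PR and may have changed in later versions:

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BeamSearchScorer, LogitsProcessorList, MinLengthLogitsProcessor

tokenizer = AutoTokenizer.from_pretrained("t5-base")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

encoder_input_ids = tokenizer(
    "translate English to German: How old are you?", return_tensors="pt"
).input_ids

num_beams = 3

# generate() would run the encoder and expand the inputs to num_beams automatically;
# for beam_search we do both by hand and pass the result via model_kwargs
model_kwargs = {
    "encoder_outputs": model.get_encoder()(
        encoder_input_ids.repeat_interleave(num_beams, dim=0), return_dict=True
    )
}

# the decoder starts from decoder_start_token_id, one row per beam
input_ids = torch.full((num_beams, 1), model.config.decoder_start_token_id, dtype=torch.long)

beam_scorer = BeamSearchScorer(
    batch_size=1,
    max_length=model.config.max_length,
    num_beams=num_beams,
    device=model.device,
)

logits_processor = LogitsProcessorList(
    [MinLengthLogitsProcessor(5, eos_token_id=model.config.eos_token_id)]
)

outputs = model.beam_search(input_ids, beam_scorer, logits_processor=logits_processor, **model_kwargs)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```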
TODO:
- Do the same refactor for TF
- Check possibility of carrying gradients through generate
- Add `GenerationOutputs`, similar to `ModelOutputs`, that allows returning attention outputs and hidden states

