Soon we will merge a big refactor of the `generate()` method: https://github.com/huggingface/transformers/pull/6949
All transformer models that have a language model head rely on the `generate()` method, e.g.:
- Bart, T5, Marian, ProphetNet for summarization, translation, …
- GPT2, XLNet for open-ended text generation
The `generate()` method has become hard to read and increasingly difficult to maintain. The goals of the refactor are thus as follows:
- Remove the `_no_beam_search_generation` and `_beam_search_generation` methods and replace them with four different generation methods corresponding to the four different types of generation: `greedy_search` (corresponds to `num_beams = 1, do_sample = False`), `sample` (corresponds to `num_beams = 1, do_sample = True`), `beam_search` (corresponds to `num_beams > 1, do_sample = False`), and `beam_sample` (corresponds to `num_beams > 1, do_sample = True`).
The following philosophy has been adopted here: the original `generate()` method is kept 100% backwards compatible (ping me on GitHub at @patrickvonplaten if your specific use case has nevertheless been broken by this PR). In addition, each of the four specific generation methods can be used directly. The specific methods are as “bare-bone” as possible, meaning that they don't contain any magic pre-processing of input tensors. For example, `generate()` automatically runs the `input_ids` through the encoder if the model is an encoder-decoder model, whereas for each specific generate method the `encoder_outputs` have to be added to the `model_kwargs`; similarly, `generate()` automatically creates the `input_ids` in case they are empty and automatically expands them for `num_beams`, whereas this has to be done manually for each specific generate method. The reason behind this design is that it gives the user much more flexibility for specific use cases of `generate()` and improves maintainability and readability. The user should not be limited in any way when directly using the specific generate methods. It should therefore pave the way to allowing backprop through generate, make beam generation much easier, and make it easier to write “higher-level” generate functions, as is done for RAG. For more information on this design, please read the docs and look into the examples of `greedy_search`, `sample`, `beam_search` and `beam_sample`.
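The mapping between the `num_beams` / `do_sample` combinations and the four methods can be sketched as a small dispatch helper. This is a hypothetical illustration only — the function name below is made up, and the real routing lives inside `generate()` itself:

```python
# Hypothetical sketch of how generate() picks one of the four specific
# generation methods from its num_beams / do_sample arguments.
# (Illustrative only; not the actual transformers code.)

def pick_generation_method(num_beams: int, do_sample: bool) -> str:
    """Return the name of the specific generation method for these settings."""
    if num_beams == 1:
        return "sample" if do_sample else "greedy_search"
    return "beam_sample" if do_sample else "beam_search"

# e.g. num_beams=5, do_sample=False selects beam search
method = pick_generation_method(num_beams=5, do_sample=False)
```

For instance, `num_beams = 1, do_sample = True` selects `sample`, while any `num_beams > 1` switches to one of the two beam variants.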
- All of the generate parameters that can be used to tweak the logits distribution for better generation results, e.g. `no_repeat_ngram_size`, `min_length`, …, are now defined as separate classes that are added to a `LogitsProcessorList`. This has the following advantages: a) better readability, b) much easier testing of these functions, c) easier addition of new logits distribution warpers. A huge thanks goes to https://github.com/turtlesoupy, who had the original idea for this design in this PR: https://github.com/huggingface/transformers/pull/5420
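To illustrate the shape of this design (a simplified sketch, not the actual implementation — the real processors operate on batched torch tensors), a minimal min-length processor and processor list over plain Python lists of floats might look like:

```python
# Simplified sketch of the logits-processor design: each constraint is its
# own class, and a LogitsProcessorList chains them. Plain Python floats are
# used here instead of torch tensors, so this is illustrative only.
import math

class MinLengthLogitsProcessor:
    """Forbid the EOS token until the sequence reaches min_length."""
    def __init__(self, min_length: int, eos_token_id: int):
        self.min_length = min_length
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: list, scores: list) -> list:
        if len(input_ids) < self.min_length:
            scores = list(scores)
            scores[self.eos_token_id] = -math.inf  # mask EOS
        return scores

class LogitsProcessorList(list):
    """Apply each processor in order, feeding the output of one into the next."""
    def __call__(self, input_ids: list, scores: list) -> list:
        for processor in self:
            scores = processor(input_ids, scores)
        return scores

processors = LogitsProcessorList([MinLengthLogitsProcessor(min_length=3, eos_token_id=0)])
# Only 2 tokens generated so far, so EOS (id 0) gets masked to -inf
scores = processors(input_ids=[5, 7], scores=[0.1, 0.4, 0.2])
```

Each constraint being its own small callable is what makes these functions easy to unit-test in isolation and easy to extend with new warpers.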
- Move all `beam_search`-relevant code into its own generation_beam_search.py file and speed up beam search. Beam search has gained more and more importance thanks to many new and improved seq2seq models. This PR moves the very hard-to-understand beam search code into its own file and makes sure that the `beam_search` generate function is easier to understand this way. Additionally, all Python list operations are now replaced by torch tensor operations, which led to a 5-10% speed-up for `beam_search` generation. This change improves speed, readability and maintainability. New beam search algorithms, such as https://arxiv.org/abs/2007.03909 , should now be easier to add to the library.
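As a rough, torch-free illustration of the core step that benefits from that vectorization — scoring every (beam, token) continuation and keeping the best 2 * num_beams candidates — a pure-Python sketch could look like this (assumed simplification: in the refactor this becomes a single top-k over a flattened score tensor rather than Python loops):

```python
# Illustrative sketch (plain Python, no torch) of one beam-search selection
# step: combine each beam's running score with the log-probability of every
# next token, then keep the 2 * num_beams best candidates overall.
# The refactored code does this with tensor ops instead of list loops.

def beam_search_step(beam_scores, next_token_logprobs, num_beams):
    """Return the top 2*num_beams (score, beam_idx, token_id) candidates."""
    candidates = []
    for beam_idx, beam_score in enumerate(beam_scores):
        for token_id, logprob in enumerate(next_token_logprobs[beam_idx]):
            candidates.append((beam_score + logprob, beam_idx, token_id))
    candidates.sort(key=lambda c: c[0], reverse=True)  # best scores first
    return candidates[: 2 * num_beams]

# Two beams, three-token vocabulary of next-token log-probabilities
top = beam_search_step(
    beam_scores=[0.0, -1.0],
    next_token_logprobs=[[-0.5, -2.0, -3.0], [-0.25, -1.0, -4.0]],
    num_beams=2,
)
```

Keeping 2 * num_beams candidates (rather than num_beams) leaves room to drop continuations that end in EOS while still having enough open beams to continue.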
- Tests have been refactored, and new, more aggressive tests have been added. The tests can also be very helpful for understanding exactly how each of the methods works. Check out `test_generation_utils.py`, `test_generation_beam_search` and `test_generation_logits_process`.
- More docstrings, especially on beam search and logits processors, to make generate more accessible and understandable to the user.
TODO:
- Do the same refactor for TF
- Check possibility of carrying gradients through generate
- Add `GenerationOutputs`, similar to `ModelOutputs`, to allow returning attention outputs and hidden states