Soon we will merge a big refactor of the
generate() method: https://github.com/huggingface/transformers/pull/6949
All transformer models that have a language model head rely on the
generate() method, e.g.
- Bart, T5, Marian, ProphetNet for summarization, translation, …
- GPT2, XLNet for open-end text generation
The `generate()` method has become hard to read and quite difficult to maintain. The goal of the refactor is thus as follows:
`_beam_search_generation` methods and replace them with four different generation methods corresponding to the four different types of generation:
- `greedy_search` (corresponds to `num_beams=1`, `do_sample=False`),
- `sample` (corresponds to `num_beams=1`, `do_sample=True`),
- `beam_search` (corresponds to `num_beams>1`, `do_sample=False`), and
- `beam_sample` (corresponds to `num_beams>1`, `do_sample=True`).
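The mapping from `generate()` arguments to these four sub-methods can be sketched as follows (a simplified illustration of the dispatch logic, not the actual code in the PR):

```python
def pick_generation_method(num_beams: int, do_sample: bool) -> str:
    """Illustrative dispatch: which sub-method generate() delegates to,
    based on the num_beams / do_sample combination described above."""
    if num_beams == 1:
        return "sample" if do_sample else "greedy_search"
    return "beam_sample" if do_sample else "beam_search"


print(pick_generation_method(num_beams=1, do_sample=False))  # greedy_search
print(pick_generation_method(num_beams=4, do_sample=True))   # beam_sample
```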
The following philosophy has been adopted here: the original
`generate()` method is kept 100% backwards compatible (ping me on GitHub under @patrickvonplaten if your specific use case has nevertheless been broken by this PR). In addition, each of the four specific generation methods can be used directly. The specific methods are as "bare-bone" as possible, meaning that they don't contain any magic pre-processing of input tensors. E.g.
`generate()` automatically runs the
`input_ids` through the encoder if the model is an encoder-decoder model vs. for each specific generate method they have to be added as an argument. Similarly,
`input_ids` are automatically created in case they are empty or automatically expanded for
`generate()` vs. this has to be done manually for each specific generate method. The reason behind this design is that it gives the user much more flexibility for specific use cases of
`generate()` and improves maintainability and readability. The user should not be limited in any way when directly using the specific generate methods. It should therefore pave the way to allowing backprop through generate, make beam generation much easier, and make it easier to write "higher level" generate functions as is done for RAG. For more information on this design, please read the docs and look into the examples of
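As an illustration of the kind of pre-processing that `generate()` performs automatically but the bare-bone methods leave to the caller, here is a plain-Python sketch of expanding the input batch once per beam (the real implementation does this with `torch` tensor operations; the function name is hypothetical):

```python
def expand_for_beams(input_ids, num_beams):
    """Repeat each input sequence num_beams times, as generate() does
    internally before delegating to beam search (list-based sketch)."""
    return [seq for seq in input_ids for _ in range(num_beams)]


batch = [[0, 5, 7], [0, 2, 9]]               # two input sequences
expanded = expand_for_beams(batch, num_beams=3)
print(len(expanded))                          # 6: each sequence repeated per beam
```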
All of the generate parameters that can be used to tweak the logits distribution for better generation results, e.g.
`min_length`, …, are now defined as separate classes that are added to a
`LogitsProcessorList`. This has the following advantages: a) better readability, b) much easier to test these functions, c) easier to add new logits distribution warpers. A huge thanks goes to https://github.com/turtlesoupy who had the original idea for this design in this PR: https://github.com/huggingface/transformers/pull/5420
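The processor-list design can be sketched in plain Python as follows. The classes here are hypothetical minimal stand-ins to show the pattern; the library's actual `LogitsProcessor` classes operate on `torch` tensors of scores:

```python
import math


class MinLengthProcessor:
    """Forbid the EOS token until a minimum length is reached
    (mirrors the idea behind min_length handling)."""

    def __init__(self, min_length, eos_token_id):
        self.min_length = min_length
        self.eos_token_id = eos_token_id

    def __call__(self, cur_length, scores):
        if cur_length < self.min_length:
            scores = list(scores)
            scores[self.eos_token_id] = -math.inf
        return scores


class ProcessorList(list):
    """Apply each processor in order -- the LogitsProcessorList idea:
    each tweak to the logits distribution is its own testable class."""

    def __call__(self, cur_length, scores):
        for processor in self:
            scores = processor(cur_length, scores)
        return scores


processors = ProcessorList([MinLengthProcessor(min_length=5, eos_token_id=0)])
scores = processors(cur_length=3, scores=[0.9, 0.1, 0.4])
print(scores[0])  # -inf: EOS blocked before min_length is reached
```

Because each processor is a small callable class, it can be unit-tested in isolation and new warpers can be added without touching `generate()` itself.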
Move all `beam_search`-relevant code into its own generation_beam_search.py file and speed up beam search. Beam search has gained more and more importance thanks to many new and improved seq2seq models. This PR moves the very difficult to understand beam search code into its own file and makes sure that the
`beam_search` generate function is easier to understand this way. Additionally, all Python list operations are now replaced by `torch.tensor` operations, which led to a 5-10% speed-up for
`beam_search` generation. This change improves speed, readability, and maintainability. New beam search algorithms, such as https://arxiv.org/abs/2007.03909 , should now be easier to add to the library.
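To make the algorithm concrete, here is a minimal list-based sketch of a single beam-search step over a toy vocabulary (for illustration only; as noted above, the refactored code performs this with batched `torch.tensor` operations such as top-k selection, which is where the speed-up comes from):

```python
def beam_step(beams, next_token_log_probs, num_beams):
    """One beam-search step.

    beams: list of (token_list, cumulative_log_prob) pairs
    next_token_log_probs: per-beam list of log-probs, one per vocab token
    Returns the num_beams best continuations by total log-prob.
    """
    candidates = []
    for (tokens, score), log_probs in zip(beams, next_token_log_probs):
        for token_id, lp in enumerate(log_probs):
            candidates.append((tokens + [token_id], score + lp))
    # Keep only the globally best num_beams hypotheses.
    candidates.sort(key=lambda c: c[1], reverse=True)
    return candidates[:num_beams]


beams = [([0], 0.0), ([1], -0.5)]
log_probs = [[-0.1, -2.0], [-0.3, -1.0]]  # vocabulary of size 2
best = beam_step(beams, log_probs, num_beams=2)
print([tokens for tokens, _ in best])  # [[0, 0], [1, 0]]
```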
Tests have been refactored and new, more aggressive tests have been added. The tests can also be very helpful to understand how each of the methods works exactly. Check out
More docstrings, especially on beam search and logits processors, to make generate more accessible and understandable to the user.
- Do the same refactor for TF
- Check possibility of carrying gradients through generate
- `ModelOutputs` that allows returning attention outputs and hidden states