How to do batch generation with the GPT2 model?
Batch generation is now possible for GPT-2 on master by leveraging the functionality shown in this PR: https://github.com/huggingface/transformers/pull/7552.
For more info on how to prepare GPT-2 for batch generation, you can check out this test:
Hi, I am the author of the PR. You can now do batch generation by calling the same `generate()`. All you need to add is the following (see the sketch after this list):

- set `tokenizer.padding_side = "left"` (and probably reset it afterwards)
- pass `attention_mask` to `generate()`
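A minimal sketch of both steps, assuming the stock `gpt2` checkpoint (the prompts and generation length are placeholders):

```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# GPT-2 has no pad token by default; reuse EOS for padding.
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

# Pad on the left so the right-most token of every row is a real token.
tokenizer.padding_side = "left"

prompts = ["Hello, my dog is", "Once upon a time there"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)

# Pass attention_mask so generate() can build the correct position_ids.
output_ids = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=20,
)
print(tokenizer.batch_decode(output_ids, skip_special_tokens=True))
```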
Explanation (see the full example at the end):

- We need `tokenizer.padding_side = "left"` because we use the logits of the right-most token to predict the next token, so the padding should be on the left.
- Passing `attention_mask` is what this PR added. Here is a summary: GPT-2 uses absolute positional embeddings (`position_ids`). Before this change, no `position_ids` were passed to the model, and the model automatically generated them from 0 to n, even if there was padding (e.g. when the input is a batch). Example: `tokens = <pad> <pad> a b c` -> `position_ids = 0 1 2 3 4`, whereas what we expect is `x x 0 1 2` (`x` means don't care). This PR adds the `position_ids` in `prepare_inputs_for_generation()`, which is called inside `generate()`, by computing them from the `attention_mask`, and that's why you need to pass it in.
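The idea, as a rough sketch (not the exact library code; see `prepare_inputs_for_generation()` for the real implementation):

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])  # <pad> <pad> a b c

# Cumulative sum of the mask yields 0-based positions for real tokens.
position_ids = attention_mask.long().cumsum(-1) - 1
# Padding positions are masked out anyway, so fill them with a dummy value.
position_ids.masked_fill_(attention_mask == 0, 1)

print(position_ids)  # tensor([[1, 1, 0, 1, 2]]) -- i.e. "x x 0 1 2" with x = 1
```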
You can find a full example in the PR.
Hi there. Thanks for your work supporting batch inference in GPT-2. However, I still have one point of confusion and could use your help. Thanks in advance!
If I want to pass `past_key_values`, how should I build the `position_ids` and attention mask? Suppose the length of my `past_key_values` is 2 and the padded input is just like your example: `<pad>, <pad>, a, b, c`. Should I change the attention mask from `0, 0, 1, 1, 1` to `1, 1, 0, 0, 1, 1, 1`, where the leading double `1` refers to the `past_key_values`?
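To make the layout concrete, here is what I mean (past length 2 is just a hypothetical value):

```python
import torch

past_length = 2  # length of my past_key_values

# Mask for the current padded input: <pad> <pad> a b c
current_mask = torch.tensor([[0, 0, 1, 1, 1]])

# Prepend ones so the mask also covers the cached past tokens,
# i.e. shape becomes (batch, past_length + current_length).
past_mask = torch.ones(current_mask.size(0), past_length, dtype=current_mask.dtype)
full_mask = torch.cat([past_mask, current_mask], dim=-1)
print(full_mask)  # tensor([[1, 1, 0, 0, 1, 1, 1]])
```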
Thanks a lot!
@patrickvonplaten @ttj I think this is a good question! Could we discuss how to do batch inference with `past_key_values`?