I am using a GPT2LMHeadModel for generation. In my use case, I'll need to call model.generate() multiple times, and the input_ids share a common prefix.
My understanding is that I could pass past_key_values as an argument to model.generate() so that it wouldn't repeatedly compute the keys and values of the shared prefix. However, how do I get this past_key_values? The generate() function returns a GreedySearchDecoderOnlyOutput object (I set beam size = 1, no sampling), which does not contain past_key_values. Is it only available in the return value of model.forward()?
So I'm wondering what a typical example of using past_key_values across multiple calls to model.generate() would look like. Thanks!