I am using a model = GPT2LMHeadModel() for generation. In my use case, I'll need to call model.generate() multiple times, and the input_ids share a common prefix.
My understanding is that I can pass past_key_values as an argument to model.generate() so that it doesn't recompute the keys and values of the shared prefix on every call. But how do I get this past_key_values in the first place? generate() returns a GreedySearchDecoderOnlyOutput object (I set beam size = 1, no sampling), which does not contain past_key_values. Is it only returned by model.forward()?
So I'm wondering what a typical example of reusing past_key_values across multiple calls to model.generate() would look like. Thanks!