Why is past_key_values not in GreedySearchDecoderOnlyOutput?

I am using model = GPT2LMHeadModel() for generation. In my use case, I need to call model.generate() multiple times, and the input_ids share a common prefix.

As I understand it, I could pass past_key_values as an argument to model.generate() so that it doesn't recompute the keys and values of the shared prefix on every call. But how do I obtain this past_key_values? The generate() function returns a GreedySearchDecoderOnlyOutput object (I set the beam size to 1, with no sampling), which does not contain past_key_values. Is it only returned by model.forward()?
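To make the idea concrete, here is a toy, framework-free sketch (not the transformers API) of what reusing a prefix cache buys you: the keys and values of the shared prefix are computed once, and each new suffix only needs its own keys and values appended before attention. It assumes a single attention head with made-up 2-dimensional weights; W_Q, W_K, W_V, project_kv, attend, last_output_full and last_output_cached are all hypothetical names for illustration.

```python
import math
from typing import List, Tuple

Vec = List[float]
Cache = Tuple[List[Vec], List[Vec]]  # (keys, values), one entry per token

# Hypothetical fixed projection weights for a 2-dimensional toy model.
W_Q = [[1.0, 0.5], [0.0, 1.0]]
W_K = [[0.5, 0.0], [1.0, 1.0]]
W_V = [[1.0, 1.0], [0.0, 2.0]]


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


def matvec(m: List[Vec], v: Vec) -> Vec:
    # Multiply a matrix (list of rows) by a vector.
    return [dot(row, v) for row in m]


def project_kv(tokens: List[Vec]) -> Cache:
    """Keys and values for each token: this is what a KV cache stores."""
    return [matvec(W_K, t) for t in tokens], [matvec(W_V, t) for t in tokens]


def attend(query_tok: Vec, keys: List[Vec], values: List[Vec]) -> Vec:
    """Softmax attention of one query token over all given keys/values."""
    q = matvec(W_Q, query_tok)
    scores = [dot(q, k) / math.sqrt(len(q)) for k in keys]
    top = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(len(values[0]))]


def last_output_full(tokens: List[Vec]) -> Vec:
    # No cache: recompute keys/values for the whole sequence.
    keys, values = project_kv(tokens)
    return attend(tokens[-1], keys, values)


def last_output_cached(prefix_cache: Cache, suffix: List[Vec]) -> Vec:
    # Reuse the prefix keys/values; only project the new suffix tokens.
    keys, values = prefix_cache
    new_k, new_v = project_kv(suffix)
    # Build new lists so the shared prefix cache is not mutated between calls.
    return attend(suffix[-1], keys + new_k, values + new_v)


if __name__ == "__main__":
    prefix = [[1.0, 0.0], [0.0, 1.0]]
    prefix_cache = project_kv(prefix)  # computed once, reused below
    for suffix in ([[1.0, 1.0]], [[0.5, -1.0], [2.0, 0.0]]):
        print(last_output_cached(prefix_cache, suffix),
              last_output_full(prefix + suffix))
```

Both functions produce the same output for every suffix; the cached one just skips re-projecting the prefix. Copying rather than extending the cached lists in place is what keeps the prefix cache valid for the next call; as far as I know, recent Transformers docs copy the cache object (e.g. with copy.deepcopy) before reusing it in generate() for the same reason.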

So I'm wondering what a typical example of using past_key_values across multiple calls to model.generate() would look like. Thanks!


Never mind. There’s a PR for it.