Forge a synthetic past_key_values batch from multiple outputs

Hello everyone,

I am working on an algorithm that uses GPT-2 (the LM head model) to get logits over the vocabulary for a given input sequence. To speed up the process, I batch the inputs so that multiple sequences are processed at once.
I saw that, to speed up generation even further, I can pass past_key_values as input to the next iteration so that keys/values that were already computed are not recomputed. This works well when generation is fully sequential (i.e., the next input is the previous output). In my algorithm, however, an output is not necessarily the next input: it might be used n steps later, or never be used as input at all.
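A toy sketch of why reusing the cache is exact. It assumes a single hypothetical attention head in NumPy; the names W_q, W_k, W_v and cached_forward are illustrative and are not GPT-2 or transformers internals. Because attention is causal, the keys/values of earlier tokens do not change when tokens are appended, so they can be computed once and reused.

```python
import numpy as np

# Toy, single-head self-attention with fixed random projections.
# A hypothetical stand-in for one GPT-2 attention layer, not the transformers code.
D = 8  # hidden size, chosen arbitrarily for the toy
rng = np.random.default_rng(0)
W_q, W_k, W_v = (rng.standard_normal((D, D)) for _ in range(3))


def attend(q, K, V):
    """Softmax attention of query rows q over all rows of K and V."""
    scores = q @ K.T / np.sqrt(D)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ V


def full_forward(x):
    """Recompute keys/values for every token; return the last token's output.

    Only the last token's output is computed, and the last token may see
    every earlier token, so no explicit causal mask is needed here.
    """
    return attend(x[-1:] @ W_q, x @ W_k, x @ W_v)


def cached_forward(x_new, cache=None):
    """Compute keys/values only for x_new, append them to the cache,
    and return the last token's output together with the updated cache."""
    K, V = x_new @ W_k, x_new @ W_v
    if cache is not None:
        K = np.concatenate([cache[0], K])
        V = np.concatenate([cache[1], V])
    return attend(x_new[-1:] @ W_q, K, V), (K, V)
```

Feeding the first four tokens, then only the fifth with the returned cache, gives the same output as recomputing all five. In a multi-layer causal model the same argument applies at every layer.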

For example:
Step 1 : input is [“Hello HuggingFace my name”]
Step 2 : input is [“Hello HuggingFace, how”]
Step 3 : input is [“Hello HuggingFace my name is”]

In this case, I would like to use the past_key_values returned at step 1 as input at step 3.
Since the inputs are batched and each element of a batch may be reused at a different later step, I want to store the key/value states of each element alongside that element.
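A minimal sketch of the bookkeeping for the step 1 / step 3 example. It assumes each element's cache is stored keyed by its token ids, and that the tokenized step 1 input is a token-level prefix of the step 3 input. PrefixCacheStore and longest_prefix are hypothetical names. The idea is to look up the longest stored prefix and feed only the remaining tokens to the model.

```python
from typing import Any, Dict, Optional, Sequence, Tuple


class PrefixCacheStore:
    """Hypothetical helper: maps a token-id prefix to the cache computed for it."""

    def __init__(self) -> None:
        self._store: Dict[Tuple[int, ...], Any] = {}

    def put(self, token_ids: Sequence[int], cache: Any) -> None:
        """Remember the cache produced for exactly these token ids."""
        self._store[tuple(token_ids)] = cache

    def longest_prefix(self, token_ids: Sequence[int]) -> Tuple[int, Optional[Any]]:
        """Return (prefix_length, cache) for the longest stored proper prefix.

        Only proper prefixes are considered, so at least one new token is left
        to feed to the model (its logits are what the next step needs).
        Returns (0, None) when nothing usable is stored.
        """
        ids = tuple(token_ids)
        for end in range(len(ids) - 1, 0, -1):
            if ids[:end] in self._store:
                return end, self._store[ids[:end]]
        return 0, None
```

With hypothetical token ids, step 1 stores [1, 2, 3, 4], step 2 stores [1, 2, 5, 6], and step 3 ([1, 2, 3, 4, 7]) finds the step 1 cache with prefix length 4, so only token 7 needs to be fed. A linear scan over prefix lengths is enough for a sketch; a trie would avoid the repeated tuple slicing if the store grows large.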

Even that is somewhat complicated, since past_key_values is a tuple of tuples with shape (16, 2, batch_size), but converting it to a NumPy array makes it possible.
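One way to store one cache per element is to slice the batch dimension out of every layer's key and value tensors. This sketch makes several assumptions:

- It uses the legacy tuple format: one (key, value) pair per layer, each of shape (batch, n_heads, seq_len, head_dim).
- The cache is held as NumPy arrays.
- The batch is right-padded, so padded positions can simply be dropped using the attention mask.

split_past is a hypothetical helper. Each stored key then has the per-element shape [n_heads, input_size, head_dim], which matches the [12, input_size, 64] size described in this post.

```python
import numpy as np


def split_past(past, attention_mask):
    """Split a batched legacy-format cache into one cache per batch element.

    past: tuple over layers of (key, value), each array of shape
          (batch, n_heads, seq_len, head_dim).
    attention_mask: (batch, seq_len) array, 1 for real tokens, 0 for padding.
    Returns a list with one entry per element: a tuple over layers of
    (key, value), each of shape (n_heads, real_len, head_dim).

    Assumption: the batch was right-padded, so real tokens kept their true
    positions and the padded positions can be dropped without side effects.
    """
    per_element = []
    for i in range(attention_mask.shape[0]):
        keep = attention_mask[i].astype(bool)  # True for real tokens
        per_element.append(
            tuple((k[i][:, keep, :], v[i][:, keep, :]) for k, v in past)
        )
    return per_element
```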
However, I am now struggling to build my input from what I am storing, because each element can have a different length. For input_ids it is easy to batch sequences of different lengths: pad them with pad_token and use attention_mask. But the past_key_values of each element also differ in size, because there are keys/values for every input token, so their size depends on the length of the input that produced them (the actual size of past_key_values for one element is [12, input_size, 64]).
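One approach that should work is to pad the cache and mask out the padded cache positions. This sketch makes several assumptions:

- Each stored cache is a tuple over layers of (key, value) NumPy arrays shaped (n_heads, past_len, head_dim).
- The cached keys/values are left-padded with zeros along the sequence axis.
- The attention_mask covers past plus new positions, with 0 on all padding.
- position_ids are passed explicitly and continue from each element's true cache length. GPT-2 uses absolute position embeddings, so positions should not be inferred from the padded length.

batch_past is a hypothetical helper.

```python
import numpy as np


def batch_past(caches, new_ids, pad_id=0):
    """Left-pad per-element caches and right-pad new tokens into one batch.

    caches: list of per-element caches, each a tuple over layers of
            (key, value) arrays shaped (n_heads, past_len_i, head_dim).
    new_ids: list of lists with the tokens to feed on top of each cache.
    pad_id: GPT-2 has no pad token; any id works because it is masked.
    Returns (past, input_ids, attention_mask, position_ids).
    """
    lengths = [c[0][0].shape[1] for c in caches]  # true cache length per element
    max_past = max(lengths)
    max_new = max(len(ids) for ids in new_ids)
    batch = len(caches)

    past = []
    for layer in range(len(caches[0])):
        padded_k, padded_v = [], []
        for cache, n in zip(caches, lengths):
            k, v = cache[layer]
            pad = ((0, 0), (max_past - n, 0), (0, 0))  # zeros on the left of the seq axis
            padded_k.append(np.pad(k, pad))
            padded_v.append(np.pad(v, pad))
        past.append((np.stack(padded_k), np.stack(padded_v)))

    input_ids = np.full((batch, max_new), pad_id, dtype=np.int64)
    attention_mask = np.zeros((batch, max_past + max_new), dtype=np.int64)
    position_ids = np.zeros((batch, max_new), dtype=np.int64)
    for i, (ids, n) in enumerate(zip(new_ids, lengths)):
        input_ids[i, : len(ids)] = ids
        attention_mask[i, max_past - n : max_past] = 1  # real cached tokens
        attention_mask[i, max_past : max_past + len(ids)] = 1  # real new tokens
        # Positions continue from this element's true cache length, not max_past.
        position_ids[i] = n + np.arange(max_new)
    return tuple(past), input_ids, attention_mask, position_ids
```

The arrays would then be converted to tensors (e.g. with torch.from_numpy) and passed to the model as input_ids, attention_mask, position_ids and past_key_values. Newer transformers versions may expect a Cache object rather than legacy tuples; one option, where available, is DynamicCache.from_legacy_cache. The next-token logits for element i sit at index len(new_ids[i]) - 1.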

Is it possible to pad this? Does it even make sense?

I am sorry if this is unclear. It is pretty hard to explain, so don’t hesitate to ask me anything that might help clarify.

Thank you in advance!