Setting "num_beams" and using "past_key_values" when calling .generate()

I have a piece of code that speeds up text generation by reusing the past_key_values computed for a fixed prefix. A simplified version is as follows:

# Run the prefix once to build its key/value cache, then reuse it during generation.
prefix_output = model(prefix_input_ids, use_cache=True)
generation_output = model.generate(postfix_input_ids, num_beams=1, use_cache=True,
                                   past_key_values=prefix_output.past_key_values)
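
For completeness, the setup around that snippet looks roughly like this (the prompts here are just placeholders, not my real inputs):

from transformers import GPT2LMHeadModel, GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2-xl")
model = GPT2LMHeadModel.from_pretrained("gpt2-xl").eval()

# Placeholder prefix/postfix -- in my real code the prefix is the long,
# shared part and the postfix is the per-example continuation.
prefix_input_ids = tokenizer("A long shared prefix goes here", return_tensors="pt").input_ids
postfix_input_ids = tokenizer(" followed by a shorter postfix", return_tensors="pt").input_ids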

Here "model" is a GPT2LMHeadModel loaded from "gpt2-xl". With num_beams=1 the code works perfectly fine. The problem is that if num_beams is set to anything greater than 1, I get the exception below (in this example I set num_beams to 3):

‘Sizes of tensors must match except in dimension 2. Expected size 1 but got size 3 for tensor number 1 in the list.’

I suspect that I should somehow pre-process prefix_output.past_key_values before passing it to model.generate() (my rough idea is sketched below), but I am not sure. Does anybody know how to fix this? Thanks.
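
To show what kind of pre-processing I have in mind: my guess is that beam search internally expands the batch dimension of the inputs by num_beams, while the cached tensors still have batch size 1, which would match the "size 1 vs. size 3" error. The sketch below repeats the cache along the batch dimension before calling generate. The helper name is mine, and it assumes the legacy tuple cache format where each layer holds key/value tensors of shape (batch, num_heads, seq_len, head_dim); I have not verified that this is the correct fix.

def expand_past_for_beams(past_key_values, num_beams):
    # Hypothetical helper: repeat each cached key/value tensor along the
    # batch dimension so it matches the inputs that beam search expands
    # to batch_size * num_beams.
    return tuple(
        tuple(t.repeat_interleave(num_beams, dim=0) for t in layer)
        for layer in past_key_values
    )

expanded_past = expand_past_for_beams(prefix_output.past_key_values, num_beams=3)
generation_output = model.generate(postfix_input_ids, num_beams=3, use_cache=True,
                                   past_key_values=expanded_past)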