I have a piece of code to accelerate text generation using past_key_values. The simplified version is as follows:
prefix_output = model(prefix_input_ids)
generation_output = model.generate(postfix_input_ids, num_beams=1, use_cache=True, past_key_values=prefix_output.past_key_values)
Here the variable "model" can be a GPT2LMHeadModel that has loaded "gpt2-xl". The code works perfectly fine. The problem is that if "num_beams" is set to greater than 1, then I get the exception below (in the example I set "num_beams" to 3):
"Sizes of tensors must match except in dimension 2. Expected size 1 but got size 3 for tensor number 1 in the list."
I suspect that I should somehow pre-process the values of "prefix_output.past_key_values" before passing them to model.generate(), but I am not sure. Does anybody know how to fix this? Thanks.
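For context, the kind of pre-processing I have in mind would look roughly like the sketch below. It is only a guess, not a confirmed fix: expand_past_for_beams is a hypothetical helper I wrote, and it assumes past_key_values is the usual tuple (one entry per layer) of (key, value) tensor pairs shaped [batch, heads, seq_len, head_dim], so that repeating along the batch dimension would match the batch * num_beams layout beam search expects.

```python
import torch

def expand_past_for_beams(past_key_values, num_beams):
    # Hypothetical helper (not from the transformers API). Beam search runs
    # num_beams hypotheses per input sequence, so every cached key/value
    # tensor ([batch, heads, seq_len, head_dim]) is repeated along the
    # batch dimension to match.
    return tuple(
        tuple(t.repeat_interleave(num_beams, dim=0) for t in layer)
        for layer in past_key_values
    )

# Toy cache standing in for prefix_output.past_key_values:
# 2 layers, batch=1, 4 heads, prefix length 5, head_dim 8.
past = tuple(
    (torch.zeros(1, 4, 5, 8), torch.zeros(1, 4, 5, 8))
    for _ in range(2)
)
expanded = expand_past_for_beams(past, num_beams=3)
print(expanded[0][0].shape)  # torch.Size([3, 4, 5, 8])
```

If something like this is the right direction, I would then pass the expanded cache as past_key_values to model.generate(); but I do not know whether generate() expects the expansion to be done by the caller or handles it internally.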