After looking at the code behind generate(), I found that it already supports this when decoder_input_ids is passed as an argument, much as the forward function does:
summary_ids = model.generate(inputs['input_ids'], decoder_input_ids=decoder_inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
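For anyone wondering what passing a decoder prefix means conceptually, here is a toy sketch in plain Python. It is not the library's implementation: `toy_generate`, `count_up`, and the `BOS`/`EOS` ids are made-up names for illustration, and the "model" is just a function that returns the last token plus one. It only shows the idea of a greedy decoding loop that continues from a supplied prefix instead of from the start token alone.

```python
from typing import Callable, List, Optional

BOS, EOS = 0, 9  # hypothetical start / end token ids, chosen for this toy only


def toy_generate(
    next_token: Callable[[List[int]], int],
    decoder_input_ids: Optional[List[int]] = None,
    max_length: int = 5,
) -> List[int]:
    """Greedy decoding that continues from an optional decoder prefix."""
    # Without a prefix, decoding starts from the start token alone.
    output = list(decoder_input_ids) if decoder_input_ids else [BOS]
    # In this toy, max_length counts the prefix tokens as well.
    while len(output) < max_length:
        token = next_token(output)
        output.append(token)
        if token == EOS:
            break
    return output


def count_up(tokens: List[int]) -> int:
    """Stand-in for a model: predicts the last token plus one."""
    return tokens[-1] + 1


free = toy_generate(count_up)
forced = toy_generate(count_up, decoder_input_ids=[BOS, 4, 5])
print(free)    # [0, 1, 2, 3, 4]
print(forced)  # [0, 4, 5, 6, 7]
```

The forced run keeps the prefix `[0, 4, 5]` untouched and only the remaining positions are generated, which is the behaviour I was after with decoder_input_ids.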
So once again I am impressed by how much capability is built into the library. I still think it would be beneficial to document this functionality explicitly somewhere.
I won’t delete the post in case someone comes looking for the same.