Generate 'continuation' for seq2seq models

I am not sure if I missed an obvious way to do this, but I didn’t find any.

Basically, the idea is this: suppose we have a seq2seq model, say BART. Right now, one can pass tokens to the encoder and then generate text with model.generate(), but there doesn’t seem to be a way to supply decoder inputs, that is, text which we want generate() to continue.

Using the example from the documentation:

from transformers import BartTokenizer, BartForConditionalGeneration, BartConfig
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, return_tensors='pt')
# Generate Summary
summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])

This outputs ['My friends']. Say we want to condition the generation to start with ‘Friends’ instead of ‘My’. It would be cool to have something like:

decoder_inputs = tokenizer(['Friends'], max_length=1024, return_tensors='pt')
# Generate Summary (decoder_inputs here is a wished-for argument, not an existing one)
summary_ids = model.generate(inputs['input_ids'], decoder_inputs=decoder_inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])

This would output a summary starting with ‘Friends’. I am aware that it is possible to do a forward pass with explicit decoder_input_ids, but I was wondering whether there is a way to do this at generation time, to take advantage of beam search and the like.

Perhaps the newly added prefix_allowed_tokens_fn offers a workaround, by having it return the desired starting tokens at the beginning of generation. Still, I was wondering whether there is a more straightforward way that I missed, or whether this would be interesting to add to the generate() functionality.
Cheers!
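
For reference, the prefix_allowed_tokens_fn workaround mentioned above could be sketched roughly as follows. generate() calls the function with the batch index and the tokens decoded so far, and expects a list of allowed token ids back. The helper name make_prefix_fn and its parameters forced_ids, vocab_size and num_start_tokens are made up for illustration. In particular, num_start_tokens=1 assumes the decoder begins from a single start token, which depends on the model config.

```python
from typing import Callable, List, Sequence


def make_prefix_fn(
    forced_ids: Sequence[int],
    vocab_size: int,
    num_start_tokens: int = 1,  # assumption: decoder begins with one start token
) -> Callable[[int, Sequence[int]], List[int]]:
    """Build a function that forces the first generated tokens to be forced_ids."""
    all_ids = list(range(vocab_size))

    def prefix_allowed_tokens_fn(batch_id: int, input_ids: Sequence[int]) -> List[int]:
        # input_ids holds the decoder tokens so far for one beam
        # (a 1-D tensor in practice; len() works on it as on a list).
        step = len(input_ids) - num_start_tokens
        if 0 <= step < len(forced_ids):
            # Still inside the prefix: allow only the next forced token.
            return [forced_ids[step]]
        # Past the prefix: allow the whole vocabulary.
        return all_ids

    return prefix_allowed_tokens_fn
```

It would then be passed along the lines of model.generate(inputs['input_ids'], prefix_allowed_tokens_fn=make_prefix_fn(prefix_ids, model.config.vocab_size), num_beams=4), where prefix_ids (a hypothetical name) holds the token ids of the desired prefix without special tokens.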


After looking at the code behind generate(), it turns out that it already supports this: if decoder_input_ids is passed, it is used in the same way as in the forward function:

summary_ids = model.generate(inputs['input_ids'], decoder_input_ids=decoder_inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)

So once again I am impressed by how many capabilities are built into the library. I still think it would be beneficial to document this functionality somewhere.
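
One detail that may be worth checking when doing this (an assumption, not verified against every library version): by default the tokenizer wraps text in special tokens, so a tokenized prefix ends with an end-of-sequence token, while the decoder normally expects to begin with its start token. Below is a minimal sketch of cleaning up the prefix ids before passing them as decoder_input_ids. The helper build_decoder_prefix is hypothetical, and the ids in the example comment are placeholders rather than real vocabulary entries.

```python
from typing import List, Sequence


def build_decoder_prefix(
    token_ids: Sequence[int],
    decoder_start_id: int,
    eos_id: int,
) -> List[int]:
    """Drop trailing EOS tokens and make sure the prefix begins with the start token."""
    ids = list(token_ids)
    # A trailing EOS would mark the sequence as already finished.
    while ids and ids[-1] == eos_id:
        ids.pop()
    # Prepend the decoder start token unless it is already there.
    if not ids or ids[0] != decoder_start_id:
        ids.insert(0, decoder_start_id)
    return ids


# Placeholder ids: 0 = BOS, 2 = EOS and decoder start, 100 = some word piece.
prefix = build_decoder_prefix([0, 100, 2], decoder_start_id=2, eos_id=2)
```

The result can be wrapped as torch.tensor([prefix]) and passed as decoder_input_ids. Some versions of generate() may add the start token on their own, so it is worth inspecting the decoded output either way.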

I won’t delete the post, in case someone comes looking for the same thing.
