I am not sure if I missed an obvious way to do this, but I didn’t find one.
Basically, the idea concerns seq2seq models such as Bart. Right now, one can pass tokens to the encoder and then generate text with model.generate(), but there doesn’t seem to be a way to supply decoder inputs, that is, text that we want the generate function to continue.
Using the example from the documentation:
from transformers import BartTokenizer, BartForConditionalGeneration
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
ARTICLE_TO_SUMMARIZE = "My friends are cool but they eat too many carbs."
inputs = tokenizer([ARTICLE_TO_SUMMARIZE], max_length=1024, truncation=True, return_tensors='pt')
# Generate Summary
summary_ids = model.generate(inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
This outputs ['My friends']. Say we want to condition the generation to start with ‘Friends’ instead of ‘My’. It would be cool to have something like:
decoder_inputs = tokenizer(['Friends'], max_length=1024, return_tensors='pt')
# Generate Summary
summary_ids = model.generate(inputs['input_ids'], decoder_inputs=decoder_inputs['input_ids'], num_beams=4, max_length=5, early_stopping=True)
print([tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summary_ids])
This would output a summary starting with ‘Friends’. I am aware that it is possible to do a forward pass with explicit decoder inputs (decoder_input_ids), but I was wondering if there is a way to do this at generation time, to take advantage of beam search and the like.
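To make the forward-pass route concrete, here is a minimal sketch of continuing a decoder prefix by hand with greedy decoding. It is deliberately library-free: score_next is a hypothetical stand-in for one model forward pass that returns next-token scores, and toy_score_next is a made-up scorer used only so the loop can run. The sketch shows what you give up compared with generate: only the single best token is kept at each step, so there is no beam search.

```python
from typing import Callable, List, Sequence


def greedy_continue(
    score_next: Callable[[List[int]], Sequence[float]],
    prefix: Sequence[int],
    eos_id: int,
    max_length: int,
) -> List[int]:
    """Extend `prefix` one token at a time, always taking the top-scoring token.

    `score_next` is a hypothetical stand-in for a forward pass with explicit
    decoder inputs: it receives the decoder tokens so far and returns one
    score per vocabulary entry.
    """
    tokens = list(prefix)
    while len(tokens) < max_length:
        scores = score_next(tokens)
        # Greedy choice: index of the highest score (no beams are tracked).
        next_id = max(range(len(scores)), key=scores.__getitem__)
        tokens.append(next_id)
        if next_id == eos_id:
            break
    return tokens


def toy_score_next(tokens: List[int]) -> List[float]:
    """Made-up scorer over a 5-token vocabulary: prefers (last token + 1) % 5."""
    vocab_size = 5
    target = (tokens[-1] + 1) % vocab_size
    return [1.0 if i == target else 0.0 for i in range(vocab_size)]


continued = greedy_continue(toy_score_next, prefix=[0, 2], eos_id=4, max_length=10)
```

With a real model, score_next would wrap a call that passes the tokens as decoder inputs and reads the scores at the last position; that wiring is left out here because it depends on the model.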
Perhaps the newly added prefix_allowed_tokens_fn offers a workaround, by having it return the desired starting tokens at the beginning of generation, but I was wondering whether there is a more straightforward way that I missed, or whether this would be interesting to add to generate?
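A rough sketch of the prefix_allowed_tokens_fn workaround, as far as I understand it: the callback receives the batch index and the decoder tokens generated so far, and returns the list of token ids allowed next. The helper name make_forced_prefix_fn and its offset parameter are my own inventions for illustration, not part of the library. offset is the assumed number of tokens already in the decoder sequence before the forced text should begin.

```python
from typing import Callable, List, Sequence


def make_forced_prefix_fn(
    prefix_ids: Sequence[int],
    vocab_size: int,
    offset: int = 1,
) -> Callable[[int, Sequence[int]], List[int]]:
    """Build a callback that forces `prefix_ids` at the start of decoding.

    Hypothetical helper. `offset` is an assumption about how many special
    tokens the decoder sequence already contains before real text starts.
    """
    prefix = list(prefix_ids)
    all_tokens = list(range(vocab_size))

    def allowed_tokens(batch_id: int, input_ids: Sequence[int]) -> List[int]:
        # Position within the forced prefix, based on the current length.
        step = len(input_ids) - offset
        if 0 <= step < len(prefix):
            # Still inside the prefix: allow only the next prefix token.
            return [prefix[step]]
        # Before or after the prefix: leave generation unconstrained.
        return all_tokens

    return allowed_tokens


# Toy check with plain lists standing in for the decoder token tensor.
fn = make_forced_prefix_fn(prefix_ids=[7, 8], vocab_size=10, offset=2)
first = fn(0, [2, 0])
second = fn(0, [2, 0, 7])
after = fn(0, [2, 0, 7, 8])
```

In the Bart example, prefix_ids would presumably come from tokenizer('Friends', add_special_tokens=False)['input_ids'], vocab_size from model.config.vocab_size, and the callback would be passed as model.generate(..., prefix_allowed_tokens_fn=fn). The right offset depends on how many special tokens the decoder sequence holds before real text begins, so it is worth checking on an actual run.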