Rewriting generate function for manual decoder input

Hi @patrickvonplaten,
This is exactly what I was looking for, thank you! Your code works well with T5 (I have also tried it for summarization).
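
For reference, here is roughly the T5 call that works for me (a minimal sketch; I used t5-small here, and `text` holds the source document):

from transformers import T5Tokenizer, T5ForConditionalGeneration
tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small')

input_ids = tokenizer(text, return_tensors="pt").input_ids

# T5 starts decoding from the pad token, so the prompt is prefixed with <pad>
decoder_input_ids = tokenizer("<pad> Anatomy is", return_tensors="pt", add_special_tokens=False).input_ids

output = model.generate(input_ids, decoder_input_ids=decoder_input_ids, num_beams=4, num_return_sequences=4)
print(tokenizer.batch_decode(output, skip_special_tokens=True))

However, the same approach does not seem to work with BART or Pegasus. When running: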

from transformers import BartTokenizer, BartForConditionalGeneration
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

# Alternatively, the same thing with Pegasus:
# from transformers import PegasusTokenizer, PegasusForConditionalGeneration
# tokenizer = PegasusTokenizer.from_pretrained('google/pegasus-large')
# model = PegasusForConditionalGeneration.from_pretrained('google/pegasus-large')

input_ids = tokenizer(text, return_tensors="pt").input_ids  # text holds the source document string

decoder_input_ids = tokenizer("<s> Anatomy is", return_tensors="pt", add_special_tokens=False).input_ids

output = model.generate(input_ids, decoder_input_ids=decoder_input_ids, num_beams=4, num_return_sequences=4)

print("With decoder_input_ids num_beams=4", tokenizer.batch_decode(output, skip_special_tokens=True))

output = model.generate(input_ids, num_beams=4, num_return_sequences=4)

print("Without decoder_input_ids num_beams=4", tokenizer.batch_decode(output, skip_special_tokens=True))

I get the following error:

TypeError                                 Traceback (most recent call last)

<ipython-input-38-271e60997201> in <module>()
      2 decoder_input_ids = tokenizer("<s> Anatomy is", return_tensors="pt", add_special_tokens=False).input_ids
      3 
----> 4 output = model.generate(input_ids, decoder_input_ids=decoder_input_ids, num_beams=4, num_return_sequences=4)
      5 
      6 print("With decoder_input_ids num_beams=4", tokenizer.batch_decode(output, skip_special_tokens=True))

/usr/local/lib/python3.6/dist-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     24         def decorate_context(*args, **kwargs):
     25             with self.__class__():
---> 26                 return func(*args, **kwargs)
     27         return cast(F, decorate_context)
     28 

/usr/local/lib/python3.6/dist-packages/transformers/generation_utils.py in generate(self, input_ids, max_length, min_length, do_sample, early_stopping, num_beams, temperature, top_k, top_p, repetition_penalty, bad_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, num_return_sequences, decoder_start_token_id, use_cache, num_beam_groups, diversity_penalty, prefix_allowed_tokens_fn, **model_kwargs)
    610                 pad_token_id=pad_token_id,
    611                 eos_token_id=eos_token_id,
--> 612                 **model_kwargs,
    613             )
    614 

/usr/local/lib/python3.6/dist-packages/transformers/generation_utils.py in beam_search(self, input_ids, beam_scorer, logits_processor, max_length, pad_token_id, eos_token_id, **model_kwargs)
   1041 
   1042         while cur_len < max_length:
-> 1043             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   1044 
   1045             outputs = self(**model_inputs, return_dict=True)

TypeError: prepare_inputs_for_generation() got multiple values for argument 'decoder_input_ids'

If I read the traceback correctly, beam_search passes the running decoder sequence positionally to prepare_inputs_for_generation, and for BART/Pegasus that positional argument is named decoder_input_ids, so it collides with the decoder_input_ids I pass as a keyword argument. Do you observe the same behaviour, or am I missing something here?
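
In the meantime, I was wondering whether prefix_allowed_tokens_fn (which I can see in the generate signature above) could be used to force the first few decoder tokens instead. A rough, untested sketch of what I have in mind, reusing tokenizer, model and input_ids from the snippet above (the callback semantics below are my assumption):

# Force the decoder to start with "<s> Anatomy is" by constraining the allowed
# next token at each position; after the forced prefix, allow the full vocabulary again.
forced_prefix = tokenizer("<s> Anatomy is", add_special_tokens=False).input_ids
all_token_ids = list(range(len(tokenizer)))

def force_prefix(batch_id, generated_so_far):
    # generated_so_far[0] is the decoder start token, so the next position to fill
    # in the forced prefix is its current length minus one
    next_pos = generated_so_far.shape[-1] - 1
    if next_pos < len(forced_prefix):
        return [forced_prefix[next_pos]]
    return all_token_ids  # no constraint once the prefix has been emitted

output = model.generate(input_ids, num_beams=4, num_return_sequences=4,
                        prefix_allowed_tokens_fn=force_prefix)
print(tokenizer.batch_decode(output, skip_special_tokens=True))

Not sure whether that is the intended use of the argument, though.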
