Hi @patrickvonplaten,
This is what I was looking for, thank you! Although your code works well with T5 (I have tried summarization as well), it does not seem to work with BART or Pegasus. When I run:
```python
from transformers import BartTokenizer, BartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

# OR, for Pegasus:
# from transformers import PegasusTokenizer, PegasusForConditionalGeneration
# tokenizer = PegasusTokenizer.from_pretrained('google/pegasus-large')
# model = PegasusForConditionalGeneration.from_pretrained('google/pegasus-large')

# `text` is the source document to summarize, as in your example
input_ids = tokenizer(text, return_tensors="pt").input_ids
decoder_input_ids = tokenizer("<s> Anatomy is", return_tensors="pt", add_special_tokens=False).input_ids

output = model.generate(input_ids, decoder_input_ids=decoder_input_ids, num_beams=4, num_return_sequences=4)
print("With decoder_input_ids num_beams=4", tokenizer.batch_decode(output, skip_special_tokens=True))

output = model.generate(input_ids, num_beams=4, num_return_sequences=4)
print("Without decoder_input_ids num_beams=4", tokenizer.batch_decode(output, skip_special_tokens=True))
```
I get the following error:
```
TypeError                                 Traceback (most recent call last)
<ipython-input-38-271e60997201> in <module>()
      2 decoder_input_ids = tokenizer("<s> Anatomy is", return_tensors="pt", add_special_tokens=False).input_ids
      3 
----> 4 output = model.generate(input_ids, decoder_input_ids=decoder_input_ids, num_beams=4, num_return_sequences=4)
      5 
      6 print("With decoder_input_ids num_beams=4", tokenizer.batch_decode(output, skip_special_tokens=True))

2 frames
/usr/local/lib/python3.6/dist-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     24     def decorate_context(*args, **kwargs):
     25         with self.__class__():
---> 26             return func(*args, **kwargs)
     27     return cast(F, decorate_context)
     28 

/usr/local/lib/python3.6/dist-packages/transformers/generation_utils.py in generate(self, input_ids, max_length, min_length, do_sample, early_stopping, num_beams, temperature, top_k, top_p, repetition_penalty, bad_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, num_return_sequences, decoder_start_token_id, use_cache, num_beam_groups, diversity_penalty, prefix_allowed_tokens_fn, **model_kwargs)
    610             pad_token_id=pad_token_id,
    611             eos_token_id=eos_token_id,
--> 612             **model_kwargs,
    613         )
    614 

/usr/local/lib/python3.6/dist-packages/transformers/generation_utils.py in beam_search(self, input_ids, beam_scorer, logits_processor, max_length, pad_token_id, eos_token_id, **model_kwargs)
   1041 
   1042         while cur_len < max_length:
-> 1043             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   1044 
   1045             outputs = self(**model_inputs, return_dict=True)

TypeError: prepare_inputs_for_generation() got multiple values for argument 'decoder_input_ids'
```
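If it helps, my guess from the last frame is that `beam_search` already passes the running decoder sequence positionally to `prepare_inputs_for_generation`, whose first parameter for BART/Pegasus is named `decoder_input_ids`, while the `decoder_input_ids` I gave to `generate` is still sitting in `model_kwargs`, so the same argument arrives twice. Here is a toy sketch of that Python error, independent of transformers (the function below is just an illustration, not the library code):

```python
# Minimal reproduction of the duplicate-argument TypeError above, with no
# transformers involved: the first positional parameter is `decoder_input_ids`
# and the same name is forwarded again through **model_kwargs.
def prepare_inputs_for_generation(decoder_input_ids, **model_kwargs):
    return {"decoder_input_ids": decoder_input_ids, **model_kwargs}

model_kwargs = {"decoder_input_ids": [0, 1, 2]}
prepare_inputs_for_generation([0, 1, 2], **model_kwargs)
# TypeError: prepare_inputs_for_generation() got multiple values for argument 'decoder_input_ids'
```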
Do you observe the same behaviour or am I missing something here?
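For reference, here is roughly the T5 variant that does work for me (a minimal sketch, assuming `t5-base` and a placeholder source text instead of my actual document):

```python
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained("t5-base")
model = T5ForConditionalGeneration.from_pretrained("t5-base")

# Placeholder input; any source document works here.
text = "summarize: Anatomy is the branch of biology concerned with the structure of organisms."
input_ids = tokenizer(text, return_tensors="pt").input_ids

# T5 starts decoding from the pad token, so the forced prefix begins with <pad>.
decoder_input_ids = tokenizer("<pad> Anatomy is", return_tensors="pt", add_special_tokens=False).input_ids

output = model.generate(input_ids, decoder_input_ids=decoder_input_ids, num_beams=4, num_return_sequences=4)
print(tokenizer.batch_decode(output, skip_special_tokens=True))
```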