Rewriting generate function for manual decoder input

Hello everybody, I am trying to reproduce the generate function of the GenerationMixin class so that I can pass a manual decoder input. I am using transformers v4.1.1. While I get nice results with the greedy_search function, I cannot reproduce the beam_search one: my RAM overflows. I do not have memory problems when using generate directly. The code is below. For now I am not using any special decoder input, only model.config.bos_token_id; my plan was to first check that everything works and then change it. I have tested this with both bart-base and pegasus-large, with the same results.

from transformers.generation_utils import GenerationMixin
# call the mixin's helper methods directly, passing the model instance as `self`
gm = GenerationMixin

min_length = config.BULLETS_MIN_LEN
max_length = config.BULLETS_MAX_LEN
num_beams = 4
early_stopping = True
no_repeat_ngram_size = 5
num_return_sequences = 1

model_kwargs = {}

pad_token_id = model.config.pad_token_id
eos_token_id = model.config.eos_token_id
decoder_start_token_id = model.config.decoder_start_token_id
bos_token_id = model.config.bos_token_id

# encode the text
input_ids = tokenizer.encode(text, return_tensors='pt')

# prepare attention mask and encoder output
model_kwargs["attention_mask"] = gm._prepare_attention_mask_for_generation(
    model, input_ids, pad_token_id, eos_token_id)

if model.config.is_encoder_decoder:
    model_kwargs = gm._prepare_encoder_decoder_kwargs_for_generation(model, input_ids, model_kwargs)

    input_ids = gm._prepare_decoder_input_ids_for_generation(
        model, input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id, **model_kwargs)

model_kwargs["use_cache"] = None

logits_processor = gm._get_logits_processor(
    model,
    repetition_penalty=None,
    no_repeat_ngram_size=no_repeat_ngram_size,
    bad_words_ids=None,
    min_length=min_length,
    eos_token_id=None,
    prefix_allowed_tokens_fn=None,
    num_beams=num_beams,
    num_beam_groups=None,
    diversity_penalty=None)

Using greedy_search:

outputs = gm.greedy_search(
    model,
    input_ids,
    logits_processor=logits_processor,
    max_length=max_length,
    pad_token_id=pad_token_id,
    eos_token_id=eos_token_id,
    **model_kwargs)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))

And the result is equal to the one obtained with generate:

real_outputs = model.generate(
    tokenizer.encode(df_cc_group.iloc[0].text[0], return_tensors='pt'),
    min_length=min_length,
    max_length=max_length,
    no_repeat_ngram_size=no_repeat_ngram_size,
    num_beams=1)

print(tokenizer.decode(real_outputs[0], skip_special_tokens=True))

However, when using beam_search, my RAM overflows:

from transformers import BeamSearchScorer

batch_size = input_ids.shape[0]

length_penalty = model.config.length_penalty

beam_scorer = BeamSearchScorer(
    batch_size=batch_size,
    max_length=max_length,
    num_beams=num_beams,
    device=model.device,
    length_penalty=length_penalty,
    do_early_stopping=early_stopping,
    num_beam_hyps_to_keep=num_return_sequences)

# expand the inputs so that each beam gets its own copy (expand_size must match num_beams)
input_ids, model_kwargs = gm._expand_inputs_for_generation(
    input_ids,
    expand_size=num_beams,
    is_encoder_decoder=model.config.is_encoder_decoder,
    **model_kwargs)

outputs = gm.beam_search(
    model,
    input_ids,
    beam_scorer,
    logits_processor=logits_processor,
    max_length=max_length,
    pad_token_id=pad_token_id,
    eos_token_id=eos_token_id,
    **model_kwargs)
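
For reference, this is roughly the equivalent call through generate() that I compare against (a sketch reusing only the settings defined above); as mentioned, this one gives me no memory problems:

# equivalent beam search through generate(), reusing the settings from the top of the post
real_outputs = model.generate(
    tokenizer.encode(text, return_tensors='pt'),
    min_length=min_length,
    max_length=max_length,
    no_repeat_ngram_size=no_repeat_ngram_size,
    num_beams=num_beams,
    early_stopping=early_stopping,
    num_return_sequences=num_return_sequences)

print(tokenizer.decode(real_outputs[0], skip_special_tokens=True))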

Thank you in advance for the help! :slight_smile:

Hey @marcoabrate,

The current version of generate (and also the one in v4.1.1) already supports passing user-specific decoder_input_ids. You just have to pass them to generate().

E.g., the following code works as expected:

from transformers import T5ForConditionalGeneration, T5TokenizerFast

model = T5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5TokenizerFast.from_pretrained("t5-small")

input_ids = tokenizer("translate English to German: How are you?", return_tensors="pt").input_ids
decoder_input_ids = tokenizer("<pad> Wie geht", return_tensors="pt", add_special_tokens=False).input_ids

output = model.generate(input_ids, decoder_input_ids=decoder_input_ids, num_beams=4, num_return_sequences=4)

print("With decoder_input_ids num_beams=4", tokenizer.batch_decode(output, skip_special_tokens=True))

output = model.generate(input_ids, num_beams=4, num_return_sequences=4)

print("Without decoder_input_ids num_beams=4", tokenizer.batch_decode(output, skip_special_tokens=True)) 

Also see this notebook for the answers to this specific use case: https://colab.research.google.com/drive/11js9We6ZtjN15hb3-PoFZBXJrcSOo_Qa?usp=sharing

This feature was enabled by the generate refactor: Big `generate()` refactor

Does this correspond to what you were trying to achieve or are you looking for some other behavior?

Hi @patrickvonplaten,
This is what I was looking for, thank you! Although your code works well with T5 (I have tried summarization as well), it does not seem to work with BART and Pegasus, since when running:

from transformers import BartTokenizer, BartForConditionalGeneration
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

# OR
'''
from transformers import PegasusTokenizer, PegasusForConditionalGeneration
tokenizer = PegasusTokenizer.from_pretrained('google/pegasus-large')
model = PegasusForConditionalGeneration.from_pretrained('google/pegasus-large')
'''

input_ids = tokenizer(text, return_tensors="pt").input_ids

decoder_input_ids = tokenizer("<s> Anatomy is", return_tensors="pt", add_special_tokens=False).input_ids

output = model.generate(input_ids, decoder_input_ids=decoder_input_ids, num_beams=4, num_return_sequences=4)

print("With decoder_input_ids num_beams=4", tokenizer.batch_decode(output, skip_special_tokens=True))

output = model.generate(input_ids, num_beams=4, num_return_sequences=4)

print("Without decoder_input_ids num_beams=4", tokenizer.batch_decode(output, skip_special_tokens=True))

I get the following error:

TypeError                                 Traceback (most recent call last)

<ipython-input-38-271e60997201> in <module>()
      2 decoder_input_ids = tokenizer("<s> Anatomy is", return_tensors="pt", add_special_tokens=False).input_ids
      3 
----> 4 output = model.generate(input_ids, decoder_input_ids=decoder_input_ids, num_beams=4, num_return_sequences=4)
      5 
      6 print("With decoder_input_ids num_beams=4", tokenizer.batch_decode(output, skip_special_tokens=True))

2 frames

/usr/local/lib/python3.6/dist-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     24         def decorate_context(*args, **kwargs):
     25             with self.__class__():
---> 26                 return func(*args, **kwargs)
     27         return cast(F, decorate_context)
     28 

/usr/local/lib/python3.6/dist-packages/transformers/generation_utils.py in generate(self, input_ids, max_length, min_length, do_sample, early_stopping, num_beams, temperature, top_k, top_p, repetition_penalty, bad_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, num_return_sequences, decoder_start_token_id, use_cache, num_beam_groups, diversity_penalty, prefix_allowed_tokens_fn, **model_kwargs)
    610                 pad_token_id=pad_token_id,
    611                 eos_token_id=eos_token_id,
--> 612                 **model_kwargs,
    613             )
    614 

/usr/local/lib/python3.6/dist-packages/transformers/generation_utils.py in beam_search(self, input_ids, beam_scorer, logits_processor, max_length, pad_token_id, eos_token_id, **model_kwargs)
   1041 
   1042         while cur_len < max_length:
-> 1043             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   1044 
   1045             outputs = self(**model_inputs, return_dict=True)

TypeError: prepare_inputs_for_generation() got multiple values for argument 'decoder_input_ids'

Do you observe the same behaviour or am I missing something here?

Yeah, that’s definitely a bug! Would you mind opening an issue about it on transformers? Here is the link: https://github.com/huggingface/transformers/issues/new/choose -> feel free to ping me (@patrickvonplaten) with a link to this thread here and I’ll take care of it asap. Should be pretty trivial to solve :slight_smile:

Thanks for taking a deeper look into this feature - this should definitely be fully supported by Transformers

Thank you for your fast reply. Here is a link to the GitHub issue: #9400


Should be fixed as soon as [Generation] Fix bug for manual decoder_input_ids + warning message by patrickvonplaten · Pull Request #9472 · huggingface/transformers · GitHub is merged. Thanks a lot for reporting this issue and making the library better :slight_smile:
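
Once that PR is merged, a call along these lines should run for BART without the TypeError above (just a sketch based on your snippet, with a placeholder text; the exact decoder prompt tokens are up to you):

# sketch of the expected behaviour after the fix (placeholder text, facebook/bart-base as in the snippet above)
from transformers import BartTokenizer, BartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

# placeholder document; replace with your own text
text = "Anatomy is the branch of biology concerned with the structure of organisms."

input_ids = tokenizer(text, return_tensors="pt").input_ids
decoder_input_ids = tokenizer("<s> Anatomy is", return_tensors="pt", add_special_tokens=False).input_ids

output = model.generate(input_ids, decoder_input_ids=decoder_input_ids, num_beams=4, num_return_sequences=4)
print("With decoder_input_ids num_beams=4", tokenizer.batch_decode(output, skip_special_tokens=True))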
