Rewriting generate function for manual decoder input

Hello everybody, I am trying to reproduce the generate function of the GenerationMixin class so that I can provide the decoder input manually. I am using transformers v4.1.1. While I get nice results with the greedy_search function, I cannot reproduce the beam_search one: my RAM overflows. I do not have memory problems when using generate directly. The code is below; model, tokenizer, text and config (my project settings) are assumed to be defined earlier, and the GenerationMixin methods are called as unbound functions with the model passed explicitly. I am not using any special decoder input for now, only model.config.bos_token_id; my plan was to check that everything worked and change it afterwards. I have tested this with both bart-base and pegasus-large, with the same results.

from transformers.generation_utils import GenerationMixin
gm = GenerationMixin

min_length = config.BULLETS_MIN_LEN
max_length = config.BULLETS_MAX_LEN
num_beams = 4
early_stopping = True
no_repeat_ngram_size = 5
num_return_sequences = 1

model_kwargs = {}

pad_token_id = model.config.pad_token_id
eos_token_id = model.config.eos_token_id
decoder_start_token_id = model.config.decoder_start_token_id
bos_token_id = model.config.bos_token_id

# encode the text
input_ids = tokenizer.encode(text, return_tensors='pt')

# prepare attention mask and encoder output
model_kwargs["attention_mask"] = gm._prepare_attention_mask_for_generation(
    model, input_ids, pad_token_id, eos_token_id)

if model.config.is_encoder_decoder:
    model_kwargs = gm._prepare_encoder_decoder_kwargs_for_generation(model, input_ids, model_kwargs)

    input_ids = gm._prepare_decoder_input_ids_for_generation(
        model, input_ids, decoder_start_token_id=decoder_start_token_id, bos_token_id=bos_token_id, **model_kwargs)

model_kwargs["use_cache"] = None

# build the logits processors that generate() would otherwise construct internally
logits_processor = gm._get_logits_processor(
    model,
    repetition_penalty=None,
    no_repeat_ngram_size=no_repeat_ngram_size,
    bad_words_ids=None,
    min_length=min_length,
    eos_token_id=None,
    prefix_allowed_tokens_fn=None,
    num_beams=num_beams,
    num_beam_groups=None,
    diversity_penalty=None)

Using greedy_search:

outputs = gm.greedy_search(
    model,
    input_ids,
    logits_processor=logits_processor,
    max_length=max_length,
    pad_token_id=pad_token_id,
    eos_token_id=eos_token_id,
    **model_kwargs)

print(tokenizer.decode(outputs[0], skip_special_tokens = True))

And the result is equal to that of

real_outputs = model.generate(tokenizer.encode(
    df_cc_group.iloc[0].text[0], return_tensors='pt'),
    min_length = min_length,
    max_length = max_length,
    no_repeat_ngram_size = no_repeat_ngram_size,
    num_beams = 1)

print(tokenizer.decode(real_outputs[0], skip_special_tokens = True))

However, when I use beam_search my RAM overflows:

from transformers import BeamSearchScorer

batch_size = input_ids.shape[0]

length_penalty = model.config.length_penalty

beam_scorer = BeamSearchScorer(
    batch_size=batch_size,
    max_length=max_length,
    num_beams=num_beams,
    device=model.device,
    length_penalty=length_penalty,
    do_early_stopping=early_stopping,
    num_beam_hyps_to_keep=num_return_sequences)

# duplicate input_ids and the encoder outputs for each beam, as generate() does internally
input_ids, model_kwargs = gm._expand_inputs_for_generation(
    input_ids,
    expand_size=num_beams,
    is_encoder_decoder=model.config.is_encoder_decoder,
    **model_kwargs)

outputs = gm.beam_search(
    model,
    input_ids,
    beam_scorer,
    logits_processor=logits_processor,
    max_length=max_length,
    pad_token_id=pad_token_id,
    eos_token_id=eos_token_id,
    **model_kwargs)
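
For reference, the equivalent end-to-end call through generate() (which does not give me memory problems) would be roughly:

real_outputs = model.generate(
    tokenizer.encode(text, return_tensors='pt'),
    min_length=min_length,
    max_length=max_length,
    no_repeat_ngram_size=no_repeat_ngram_size,
    early_stopping=early_stopping,
    num_beams=num_beams,
    num_return_sequences=num_return_sequences)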

Thank you in advance for the help! :slight_smile:

Hey @marcoabrate,

The current version of generate (and also the one in v4.1.1) already supports providing user-specific decoder_input_ids. You just have to pass them to generate().

E.g., the following code works as expected:

from transformers import T5ForConditionalGeneration, T5TokenizerFast

model = T5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5TokenizerFast.from_pretrained("t5-small")

input_ids = tokenizer("translate English to German: How are you?", return_tensors="pt").input_ids
decoder_input_ids = tokenizer("<pad> Wie geht", return_tensors="pt", add_special_tokens=False).input_ids

output = model.generate(input_ids, decoder_input_ids=decoder_input_ids, num_beams=4, num_return_sequences=4)

print("With decoder_input_ids num_beams=4", tokenizer.batch_decode(output, skip_special_tokens=True))

output = model.generate(input_ids, num_beams=4, num_return_sequences=4)

print("Without decoder_input_ids num_beams=4", tokenizer.batch_decode(output, skip_special_tokens=True)) 

Also see this notebook for answers to this specific use case: https://colab.research.google.com/drive/11js9We6ZtjN15hb3-PoFZBXJrcSOo_Qa?usp=sharing

This feature was enabled by the generate refactor: Big `generate()` refactor

Does this correspond to what you were trying to achieve or are you looking for some other behavior?

Hi @patrickvonplaten,
This is what I was looking for, thank you! Although your code works well with T5 (I tried summarization as well), it does not seem to work with BART and Pegasus: when running

from transformers import BartTokenizer, BartForConditionalGeneration
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

# OR
'''
from transformers import PegasusTokenizer, PegasusForConditionalGeneration
tokenizer = PegasusTokenizer.from_pretrained('google/pegasus-large')
model = PegasusForConditionalGeneration.from_pretrained('google/pegasus-large')
'''

input_ids = tokenizer(text, return_tensors="pt").input_ids

decoder_input_ids = tokenizer("<s> Anatomy is", return_tensors="pt", add_special_tokens=False).input_ids

output = model.generate(input_ids, decoder_input_ids=decoder_input_ids, num_beams=4, num_return_sequences=4)

print("With decoder_input_ids num_beams=4", tokenizer.batch_decode(output, skip_special_tokens=True))

output = model.generate(input_ids, num_beams=4, num_return_sequences=4)

print("Without decoder_input_ids num_beams=4", tokenizer.batch_decode(output, skip_special_tokens=True))

I get the following error:

TypeError                                 Traceback (most recent call last)

<ipython-input-38-271e60997201> in <module>()
      2 decoder_input_ids = tokenizer("<s> Anatomy is", return_tensors="pt", add_special_tokens=False).input_ids
      3 
----> 4 output = model.generate(input_ids, decoder_input_ids=decoder_input_ids, num_beams=4, num_return_sequences=4)
      5 
      6 print("With decoder_input_ids num_beams=4", tokenizer.batch_decode(output, skip_special_tokens=True))

/usr/local/lib/python3.6/dist-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     24         def decorate_context(*args, **kwargs):
     25             with self.__class__():
---> 26                 return func(*args, **kwargs)
     27         return cast(F, decorate_context)
     28 

/usr/local/lib/python3.6/dist-packages/transformers/generation_utils.py in generate(self, input_ids, max_length, min_length, do_sample, early_stopping, num_beams, temperature, top_k, top_p, repetition_penalty, bad_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, num_return_sequences, decoder_start_token_id, use_cache, num_beam_groups, diversity_penalty, prefix_allowed_tokens_fn, **model_kwargs)
    610                 pad_token_id=pad_token_id,
    611                 eos_token_id=eos_token_id,
--> 612                 **model_kwargs,
    613             )
    614 

/usr/local/lib/python3.6/dist-packages/transformers/generation_utils.py in beam_search(self, input_ids, beam_scorer, logits_processor, max_length, pad_token_id, eos_token_id, **model_kwargs)
   1041 
   1042         while cur_len < max_length:
-> 1043             model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   1044 
   1045             outputs = self(**model_inputs, return_dict=True)

TypeError: prepare_inputs_for_generation() got multiple values for argument 'decoder_input_ids'

Do you observe the same behaviour or am I missing something here?

Yeah, that’s definitely a bug! Would you mind opening an issue about it on transformers? Here is the link: https://github.com/huggingface/transformers/issues/new/choose. Feel free to ping me (@patrickvonplaten) with a link to this thread and I’ll take care of it asap. Should be pretty trivial to solve :slight_smile:

Thanks for taking a deeper look into this feature; it should definitely be fully supported by Transformers.

Thank you for your fast reply; here is a link to the GitHub issue: #9400

Should be fixed as soon as [Generation] Fix bug for manual decoder_input_ids + warning message by patrickvonplaten · Pull Request #9472 · huggingface/transformers · GitHub is merged. Thanks a lot for reporting this issue and making the library better :slight_smile:

@marcoabrate Forgive my ignorance (and for hijacking your old thread), but I’m not totally sure of the utility of using decoder_input_ids, and wonder if it might help with my problem.

I’m doing sentence infilling with BART and want to ensure that a relatively short passage of tokens is infilled for each <mask>. As it is, I’m noticing that it will generally do the infilling okay, but it sometimes won’t “reconnect” with the following context properly and can generate a lot more content than expected. With decoder_input_ids, could I perhaps give a kind of “schematic” of the output, like

original = "Sentence with several magical words arranged in such a way as to promote a fulfilling existence."
text_input = "Sentence with <mask> arranged in <mask> promote a fulfilling existence."
decoder_schematic = "Sentence with <pad> arranged in <pad> promote a fulfilling existence."

where the “decoder_schematic” would be used for the decoder_input_ids? My thinking is that perhaps this would make the task really explicit, rather than just relying on the model’s training to ensure adherence to the original structure.
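
For concreteness, here is a rough sketch of what I have in mind, where the "schematic" is simply tokenized and passed as decoder_input_ids (using bart-base as above; this is illustrative, not something I expect to work as-is):

from transformers import BartTokenizer, BartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

text_input = "Sentence with <mask> arranged in <mask> promote a fulfilling existence."
decoder_schematic = "Sentence with <pad> arranged in <pad> promote a fulfilling existence."

input_ids = tokenizer(text_input, return_tensors="pt").input_ids
# use the schematic, pads and all, as the forced decoder prefix
decoder_input_ids = tokenizer(decoder_schematic, return_tensors="pt", add_special_tokens=False).input_ids

output = model.generate(input_ids, decoder_input_ids=decoder_input_ids)
print(tokenizer.batch_decode(output, skip_special_tokens=True))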

Does that make sense?

Okay, I actually just tried something like this and it clearly doesn’t work… :rofl:

UPDATE: I’d actually forgotten that decoder_input_ids is the “labels” tensor during training, though I’m still not sure what it does for generation. I tried making a clone of my input, swapping <mask> with <pad> and rotating with torch.roll, but it still clearly doesn’t do what I was thinking… all it seems to do is force decoder_input_ids as the output (which makes sense, I suppose).
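
Roughly, the torch.roll version I tried looked like this (a sketch, reusing input_ids, tokenizer and model from the snippet above):

import torch

# clone the masked encoder input and swap each <mask> for <pad>...
decoder_input_ids = input_ids.clone()
decoder_input_ids[decoder_input_ids == tokenizer.mask_token_id] = tokenizer.pad_token_id
# ...then rotate one position to the right so it lines up like a decoder prefix
decoder_input_ids = torch.roll(decoder_input_ids, shifts=1, dims=1)

output = model.generate(input_ids, decoder_input_ids=decoder_input_ids)
print(tokenizer.batch_decode(output, skip_special_tokens=True))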