Generate without using the generate method

Posting this here for visibility. What if you want to decode the output of a generative seq2seq model (like T5, BART, etc.) yourself, without using the .generate() method? The code example below illustrates this.

Suppose that the model is given a long text for which it needs to generate a summary. We illustrate here how to manually decode the generated ids autoregressively: at each time step, we append the token id predicted by the model to the decoder_input_ids, which are then fed back to the decoder at the next time step. At the beginning, we only feed the decoder_start_token_id to the decoder of the model.

from transformers import BartTokenizer, BartForConditionalGeneration
import torch

model_name = "sshleifer/distilbart-cnn-6-6"
tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)

text = """The tower is 324 metres (1,063 ft) tall, about the same height as an 81-storey building, and the tallest structure in Paris. Its base is square, measuring 125 metres (410 ft) on each side. During its construction, the Eiffel Tower surpassed the Washington Monument to become the tallest man-made structure in the world, a title it held for 41 years until the Chrysler Building in New York City was finished in 1930. It was the first structure to reach a height of 300 metres. Due to the addition of a broadcasting aerial at the top of the tower in 1957, it is now taller than the Chrysler Building by 5.2 metres (17 ft). Excluding transmitters, the Eiffel Tower is the second tallest free-standing structure in France after the Millau Viaduct."""

input_ids = tokenizer(text, return_tensors="pt").input_ids

decoder_input_ids = [model.config.decoder_start_token_id]
predicted_ids = []
for i in range(20):
    outputs = model(input_ids=input_ids, decoder_input_ids=torch.tensor([decoder_input_ids]))
    # take the logits of the last decoder position
    logits = outputs.logits[:, -1, :]
    # perform argmax on the last dimension (i.e. greedy decoding)
    predicted_id = logits.argmax(-1).item()
    predicted_ids.append(predicted_id)
    print(tokenizer.decode([predicted_id]))
    # add the predicted id to decoder_input_ids for the next time step
    decoder_input_ids = decoder_input_ids + [predicted_id]

This will print:

The
 E
iff
el
 Tower
 is
 324
 metres
 (
1
,
06
3
 ft
)
 tall
,
 about
 the
 same

The final result can also be printed using print(tokenizer.decode(predicted_ids)):

The Eiffel Tower is 324 metres (1,063 ft) tall, about the same

Note that we’ve only done 20 time steps here. Normally, one continues until the model generates the EOS (end of sequence) token, which for BART is </s>.
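For completeness, here is a minimal sketch of the same greedy loop that stops as soon as the EOS token is predicted (using model.config.eos_token_id), with a maximum length as a safety net; model, tokenizer and input_ids are assumed to be defined as above:

# same greedy loop, but stop once the model predicts the EOS token
decoder_input_ids = [model.config.decoder_start_token_id]
predicted_ids = []
max_length = 100  # safety net in case EOS is never predicted
for _ in range(max_length):
    outputs = model(input_ids=input_ids, decoder_input_ids=torch.tensor([decoder_input_ids]))
    predicted_id = outputs.logits[:, -1, :].argmax(-1).item()
    if predicted_id == model.config.eos_token_id:
        break
    predicted_ids.append(predicted_id)
    decoder_input_ids.append(predicted_id)

print(tokenizer.decode(predicted_ids))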


Hi Niels, thanks for sharing the code. Would you mind also sharing some examples of situations in which you would prefer not to use the .generate() method?

If you are deploying your model on Triton Inference Server and running inference through the Triton client, there is no generate() method available to help you; I used this method to decode the model's output.


Let us suppose we want to restrict our vocabulary to some specific set of tokens (that changes dynamically with each time step). What is the best way of incorporating that? Other than decoding each token individually?
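One way to do that on top of the manual loop above (just a sketch, not an official API) is to mask out the logits of every disallowed token id before the argmax at each step; the function allowed_ids_for_step below is hypothetical and stands in for whatever produces the allowed set at step i:

# sketch: restrict each step to a (possibly changing) set of allowed token ids
decoder_input_ids = [model.config.decoder_start_token_id]
for i in range(20):
    outputs = model(input_ids=input_ids, decoder_input_ids=torch.tensor([decoder_input_ids]))
    logits = outputs.logits[:, -1, :]
    # set the logits of all disallowed ids to -inf so they can never be picked
    mask = torch.full_like(logits, float("-inf"))
    allowed = torch.tensor(allowed_ids_for_step(i))  # hypothetical helper
    mask[:, allowed] = 0.0
    predicted_id = (logits + mask).argmax(-1).item()
    decoder_input_ids.append(predicted_id)

If you do end up using .generate(), its prefix_allowed_tokens_fn argument covers a similar use case.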

Thanks for the post. I’m finding this to be much slower than the generate() method for my use case (the Whisper model). Is this expected?

That might be because this doesn’t cache the hidden states when generating, if I understand correctly. You would need to keep past_key_values or something like that by making sure use_cache is True in your model config.

Otherwise in the above snippet you’re re-computing the entire past sequence every time you want a next token, despite the fact that causal attention means all the past hidden states are constant.
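A rough sketch of what that could look like, assuming the model accepts encoder_outputs, past_key_values and use_cache in its forward pass (as BART does): run the encoder only once, and at every step feed just the newly predicted token together with the cached key/value states, which is essentially what .generate() does under the hood.

# sketch: greedy decoding with key/value caching
encoder_outputs = model.get_encoder()(input_ids)  # run the encoder only once
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
past_key_values = None
predicted_ids = []
for _ in range(20):
    outputs = model(
        encoder_outputs=encoder_outputs,
        decoder_input_ids=decoder_input_ids,
        past_key_values=past_key_values,
        use_cache=True,
    )
    past_key_values = outputs.past_key_values
    predicted_id = outputs.logits[:, -1, :].argmax(-1)
    predicted_ids.append(predicted_id.item())
    # only the newly predicted token is fed at the next step
    decoder_input_ids = predicted_id.unsqueeze(-1)

print(tokenizer.decode(predicted_ids))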

This may help a lot. What if the decoding is using beam search?
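The snippets above only cover greedy decoding, but the same manual loop can be extended to beam search. Below is a minimal, unbatched sketch (no length penalty, no caching, and the encoder is re-run for every beam at every step), meant only to show the mechanics; .generate() implements this far more efficiently:

import torch.nn.functional as F

num_beams = 4
max_steps = 20
# each beam is a (token_ids, cumulative log-probability) pair
beams = [([model.config.decoder_start_token_id], 0.0)]
finished = []

for _ in range(max_steps):
    candidates = []
    for token_ids, score in beams:
        outputs = model(input_ids=input_ids, decoder_input_ids=torch.tensor([token_ids]))
        log_probs = F.log_softmax(outputs.logits[:, -1, :], dim=-1).squeeze(0)
        top_log_probs, top_ids = log_probs.topk(num_beams)
        for lp, tid in zip(top_log_probs.tolist(), top_ids.tolist()):
            candidates.append((token_ids + [tid], score + lp))
    # keep the best candidates; beams ending in EOS are moved to `finished`
    candidates.sort(key=lambda c: c[1], reverse=True)
    beams = []
    for token_ids, score in candidates:
        if token_ids[-1] == model.config.eos_token_id:
            finished.append((token_ids, score))
        elif len(beams) < num_beams:
            beams.append((token_ids, score))
    if not beams:
        break

best_ids, _ = max(finished + beams, key=lambda c: c[1])
print(tokenizer.decode(best_ids, skip_special_tokens=True))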