Control EncoderDecoderModel to generate tokens step by step

Hey, I am writing a model whose baseline is bert2bert for text summarization, but I want to add a specific layer on top of the decoder. For example, I want to change the decoder's LM head by concatenating another vector. However, the decoder outputs all the hidden states at once, and I want to control it for step-by-step decoding. In other words, I want to use the concatenated vector as the hidden state for generation and use the generated word vector as the next step's input. How can I change the model, or call the interface properly?

My explanation may not be very clear. What I mean is: in EncoderDecoderModel, I load the model like this:

model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased')

I want to modify the structure of the LM head, or manipulate the decoder's output hidden state at each single step, to do use-case-specific generation. Is this possible?
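One way to do this without editing the library code is sketched below, under several assumptions: `ConcatLMHead` is a hypothetical module (not part of transformers), the extra vector `extra` is a fixed per-example tensor, and the model is a tiny randomly initialised bert2bert so the snippet runs offline. Passing `output_hidden_states=True` exposes the decoder's final-layer hidden states as `decoder_hidden_states[-1]`; the newest position's state is concatenated with the extra vector, projected to the vocabulary, and the chosen token is appended to the decoder input for the next step.

```python
import torch
import torch.nn as nn
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel


class ConcatLMHead(nn.Module):
    """Hypothetical head: concatenates an extra vector to each decoder hidden state
    and projects the result onto the vocabulary."""

    def __init__(self, hidden_size: int, extra_size: int, vocab_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size + extra_size, vocab_size)

    def forward(self, hidden: torch.Tensor, extra: torch.Tensor) -> torch.Tensor:
        # hidden: (batch, tgt_len, hidden_size), extra: (batch, extra_size)
        extra = extra.unsqueeze(1).expand(-1, hidden.size(1), -1)
        return self.proj(torch.cat([hidden, extra], dim=-1))  # (batch, tgt_len, vocab_size)


torch.manual_seed(0)
# Tiny random bert2bert so the sketch runs offline; in practice use
# EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased').
sizes = dict(vocab_size=100, hidden_size=32, num_hidden_layers=2,
             num_attention_heads=2, intermediate_size=64)
config = EncoderDecoderConfig.from_encoder_decoder_configs(BertConfig(**sizes), BertConfig(**sizes))
model = EncoderDecoderModel(config=config).eval()
head = ConcatLMHead(hidden_size=32, extra_size=8, vocab_size=100).eval()

src = torch.randint(5, 100, (1, 7))  # encoder input ids
extra = torch.randn(1, 8)            # the vector to concatenate (assumed fixed per example)
generated = torch.tensor([[1]])      # placeholder start token id

with torch.no_grad():
    for _ in range(5):
        out = model(input_ids=src, decoder_input_ids=generated,
                    output_hidden_states=True, return_dict=True)
        last_hidden = out.decoder_hidden_states[-1][:, -1:, :]  # final-layer state of the newest position
        logits = head(last_hidden, extra)[:, -1, :]             # (batch, vocab) from the custom head
        next_token = logits.argmax(dim=-1, keepdim=True)        # greedy pick, shape (batch, 1)
        generated = torch.cat([generated, next_token], dim=1)   # becomes the next step's input

print(generated)
```

In this sketch the model's built-in LM head still runs and its logits are simply ignored, and the new head is untrained; you would train it, for example with a cross-entropy loss on the teacher-forced decoder hidden states.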

You can simply run the model multiple times to generate a conditional sequence. For example, you first give the start token as the decoder input (`decoder_input_ids`), make a prediction, append that token to the input sequence and run the forward pass again. This gives you a purely conditionally generated sequence. It is, however, not very efficient because it does not use a cache: at every step it recomputes the outputs for all previous time steps, even though it already computed them in the previous generation step. You could fix that by adding a notion of cache to the code (returning the previous keys and values, which you then feed back in at the next step). GPT-2 implements this cache with variables named past and present.
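To make the cache idea concrete, here is a toy single-head attention layer in plain PyTorch (random weights, no real model; all names are made up for illustration). Without a cache, every step recomputes keys and values for all positions seen so far; with a cache, each step computes the key and value of the new position only, appends them to the stored ones (the "present"), and hands the grown cache to the next step (the "past"). Both give the same outputs.

```python
import torch

torch.manual_seed(0)
d = 16
Wq, Wk, Wv = (torch.randn(d, d) / d ** 0.5 for _ in range(3))  # toy projection weights


def attend(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    # q: (1, d); k, v: (t, d) -> (1, d)
    weights = torch.softmax(q @ k.T / d ** 0.5, dim=-1)
    return weights @ v


x = torch.randn(6, d)  # embeddings of 6 decoder positions, one per generation step

# Without a cache: at step t, keys/values for positions 0..t are all recomputed.
no_cache = [attend(x[t:t + 1] @ Wq, x[: t + 1] @ Wk, x[: t + 1] @ Wv) for t in range(len(x))]

# With a cache: only the new position's key/value is computed and appended ("present"),
# and the grown cache is reused at the next step ("past").
past_k, past_v, with_cache = torch.empty(0, d), torch.empty(0, d), []
for t in range(len(x)):
    past_k = torch.cat([past_k, x[t:t + 1] @ Wk])
    past_v = torch.cat([past_v, x[t:t + 1] @ Wv])
    with_cache.append(attend(x[t:t + 1] @ Wq, past_k, past_v))

print(all(torch.allclose(a, b, atol=1e-5) for a, b in zip(no_cache, with_cache)))
```

A real decoder keeps one such key/value cache per layer; in encoder-decoder models the cross-attention keys and values can be cached as well, since the encoder output does not change between steps.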

@claartje-barkhof, thank you for your response. Can you provide an example for the bert2bert model?
I build the model like this:

  model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased')

And when I run the forward method like this:

outputs = model(input_ids=src, decoder_input_ids=dst, labels=dst, return_dict=True)

it outputs all the logits at once.
I don't know how to control it.

If you want to generate incrementally, it should be something along the lines of the following. Read this as a sketch (`src`, `start_token_id`, `eos_token_id` and `max_length` are placeholders):

import torch
import torch.nn.functional as F

# model: the EncoderDecoderModel loaded above; src: encoder input ids of shape (1, src_len)
generated_so_far = torch.tensor([[start_token_id]])
while generated_so_far[0, -1].item() != eos_token_id and generated_so_far.size(1) < max_length:
    # incrementally build up the decoder inputs with your previous predictions, so each
    # prediction depends on both the encoder inputs and the previous outputs
    outputs = model(input_ids=src, decoder_input_ids=generated_so_far, return_dict=True)
    next_token_logits = outputs.logits[0, -1, :]  # logits of the last decoder position
    # optionally apply top-k / top-p filtering here (older transformers versions exported a top_k_top_p_filtering helper)
    next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1)  # shape (1,)
    generated_so_far = torch.cat((generated_so_far, next_token.unsqueeze(0)), dim=1)
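The reason a single forward call returns all the logits at once is teacher forcing: the decoder is fed the whole target, and a causal mask prevents position `t` from attending to later positions. So `logits[:, t]` from one call is the same prediction you would get by feeding only `dst[:, :t+1]` and taking the last position. The sketch below checks this on a tiny randomly initialised bert2bert (the sizes and token ids are arbitrary choices so it runs offline).

```python
import torch
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

torch.manual_seed(0)
sizes = dict(vocab_size=100, hidden_size=32, num_hidden_layers=2,
             num_attention_heads=2, intermediate_size=64)
config = EncoderDecoderConfig.from_encoder_decoder_configs(BertConfig(**sizes), BertConfig(**sizes))
model = EncoderDecoderModel(config=config).eval()  # eval() disables dropout

src = torch.randint(5, 100, (1, 7))  # encoder input ids
dst = torch.randint(5, 100, (1, 5))  # decoder input ids (the "target")

with torch.no_grad():
    # One call: logits for every decoder position at once (teacher forcing).
    full = model(input_ids=src, decoder_input_ids=dst, return_dict=True).logits  # (1, 5, vocab)
    # Step by step: feed only the prefix dst[:, :t+1] and keep the last position.
    steps = torch.stack(
        [model(input_ids=src, decoder_input_ids=dst[:, : t + 1], return_dict=True).logits[:, -1, :]
         for t in range(dst.size(1))],
        dim=1,
    )

print(torch.allclose(full, steps, atol=1e-4))
```

At inference time you don't have the target, so you use the step-by-step loop and feed your own predictions back in.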

Thank you very much!
I want to ask another question.
Do you know how much slower this step-by-step method is if the cache mechanism is not implemented?
It seems that at every single generation step the whole source text passes through the encoder again, although it only needs to pass once.

That very much depends on the size of the model, the batch size, the sequence length you want/expect to generate, and your available resources.

I tried generating without a cache with GPT-2 the other day and ran into issues quite soon. Longer sequences took way too long: I think already with batch size 8, sequence length 30 and the GPT-2 base model it started to get really slow on my local machine. It also depends on whether you want to decode a lot of sequences at once, etc. For just trying a few examples, you might be okay without the cache.

I would recommend reading this blog post about caching to understand the concept, and then checking the GPT-2 code to see how it is implemented.
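On the point raised in the question that the source only needs to pass through the encoder once: `EncoderDecoderModel` accepts precomputed `encoder_outputs`, so the encoder can be run a single time and its output reused at every decoding step. This greedy sketch uses a tiny randomly initialised bert2bert; `start_id`, `eos_id` and `max_len` are placeholder values (with real `bert-base-uncased` weights you would typically use the tokenizer's `[CLS]` and `[SEP]` ids). The decoder is still recomputed over the whole prefix at each step; that part is what the key/value cache avoids.

```python
import torch
from transformers import BertConfig, EncoderDecoderConfig, EncoderDecoderModel

torch.manual_seed(0)
sizes = dict(vocab_size=100, hidden_size=32, num_hidden_layers=2,
             num_attention_heads=2, intermediate_size=64)
config = EncoderDecoderConfig.from_encoder_decoder_configs(BertConfig(**sizes), BertConfig(**sizes))
model = EncoderDecoderModel(config=config).eval()

src = torch.randint(5, 100, (1, 7))
src_mask = torch.ones_like(src)
start_id, eos_id, max_len = 1, 2, 10  # placeholder ids and length limit

with torch.no_grad():
    # Run the encoder a single time ...
    encoder_outputs = model.get_encoder()(input_ids=src, attention_mask=src_mask, return_dict=True)
    generated = torch.tensor([[start_id]])
    while generated.size(1) < max_len and generated[0, -1].item() != eos_id:
        # ... and pass its output to every decoder step instead of input_ids.
        out = model(encoder_outputs=encoder_outputs, attention_mask=src_mask,
                    decoder_input_ids=generated, return_dict=True)
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=1)

    # Sanity check: reusing encoder_outputs gives the same logits as re-encoding the source.
    ref = model(input_ids=src, attention_mask=src_mask,
                decoder_input_ids=generated, return_dict=True).logits
    reuse = model(encoder_outputs=encoder_outputs, attention_mask=src_mask,
                  decoder_input_ids=generated, return_dict=True).logits

print(generated, torch.allclose(ref, reuse, atol=1e-5))
```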

Thank you !

@claartje-barkhof @guoziyuan Hi guys, almost two years later, but this thread came in very handy!

Could you expand a bit more on how to perform caching? I read the Illustrated GPT-2 blog post, but it would be great to have a more concrete example of how to achieve it. For reference, I'm using a BART model for summarization.
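Here is a minimal greedy decoding sketch with the cache, written for `BartForConditionalGeneration`. A tiny random config stands in for a pretrained checkpoint so it runs offline, and `greedy_with_cache` is a made-up helper. The encoder runs once; at each step only the newest decoder token is fed, together with the `past_key_values` from the previous step and `use_cache=True`, and the returned `past_key_values` are carried to the next step. The last lines compare the per-step logits with one uncached forward pass over the generated prefix. The cache format (tuples vs. cache objects) differs between transformers versions, so the sketch treats `past_key_values` as opaque and passes them back unchanged; `EncoderDecoderModel` (bert2bert) takes the same `encoder_outputs`, `past_key_values` and `use_cache` arguments.

```python
import torch
from transformers import BartConfig, BartForConditionalGeneration

torch.manual_seed(0)
# Tiny random BART so this runs offline; for real summarization you would load a
# pretrained checkpoint with BartForConditionalGeneration.from_pretrained(...).
config = BartConfig(vocab_size=100, d_model=32, encoder_layers=2, decoder_layers=2,
                    encoder_attention_heads=2, decoder_attention_heads=2,
                    encoder_ffn_dim=64, decoder_ffn_dim=64, max_position_embeddings=64)
model = BartForConditionalGeneration(config).eval()

src = torch.randint(4, 100, (1, 9))  # encoder input ids (avoids the special ids 0-3)
src_mask = torch.ones_like(src)


@torch.no_grad()
def greedy_with_cache(max_new_tokens: int = 8):
    encoder_outputs = model.get_encoder()(input_ids=src, attention_mask=src_mask)  # encoder runs once
    generated = torch.tensor([[config.decoder_start_token_id]])
    past, step_logits = None, []
    for _ in range(max_new_tokens):
        out = model(
            encoder_outputs=encoder_outputs,
            attention_mask=src_mask,
            decoder_input_ids=generated[:, -1:],  # only the newest token; older keys/values come from `past`
            past_key_values=past,
            use_cache=True,
        )
        past = out.past_key_values  # this step's "present" becomes the next step's "past"
        step_logits.append(out.logits[:, -1, :])
        next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=1)
        if next_token.item() == config.eos_token_id:
            break
    return generated, torch.stack(step_logits, dim=1)


generated, cached_logits = greedy_with_cache()
with torch.no_grad():
    # Reference: one uncached forward pass over the whole generated prefix (teacher forcing).
    full_logits = model(input_ids=src, attention_mask=src_mask,
                        decoder_input_ids=generated[:, :-1]).logits

print(generated, torch.allclose(cached_logits, full_logits, atol=1e-4))
```

If you don't need custom logic at each step, `model.generate(...)` already does this, and it uses the cache by default.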