Control EncoderDecoderModel to generate tokens step by step

Hey, I am writing a model whose baseline is bert2bert for text summarization, but I want to add a specific layer on top of the decoder. For example, I want to change the LM head of the decoder by concatenating another vector to the hidden state. The problem is that the decoder outputs all the hidden states at once, while I want to control it for step-by-step decoding. In other words, I want to use the concatenated vector as the hidden state for generating the current token, and feed the generated token back in as the next step’s input. How can I change the model or call the interface properly?

My explanation may not be very clear. What I mean is: with EncoderDecoderModel, I load the model like this

model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased')

I want to modify the structure of the LM head, or manipulate the decoder’s output hidden state at each single step, to make a use-case-specific generation. Is this possible?
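To make it concrete, here is a rough sketch of the kind of head I have in mind (just a sketch; ConcatLMHead and extra_vector are placeholder names, not anything from the library):

import torch
import torch.nn as nn

class ConcatLMHead(nn.Module):
    # hypothetical replacement head: concatenate an extra vector to every decoder
    # hidden state before projecting to the vocabulary
    def __init__(self, hidden_size, extra_size, vocab_size):
        super().__init__()
        self.proj = nn.Linear(hidden_size + extra_size, vocab_size)

    def forward(self, decoder_hidden_states, extra_vector):
        # decoder_hidden_states: (batch, seq_len, hidden_size)
        # extra_vector: (batch, extra_size), repeated along the sequence dimension
        expanded = extra_vector.unsqueeze(1).expand(-1, decoder_hidden_states.size(1), -1)
        return self.proj(torch.cat([decoder_hidden_states, expanded], dim=-1))

I would then run the model with output_hidden_states=True, take the last decoder hidden state from the output, and apply this head myself instead of using the model’s own logits.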

You can just run the model multiple times to generate a conditional sequence. For example, you first give the start token as input_ids, make a prediction, append that token to the input sequence, and forward again. This gives you a purely conditionally generated sequence. It is, however, not very efficient, as it does not make use of a ‘cache’: it recomputes the outputs for every time step, even if it already computed values for those time steps in previous generation steps. You could fix that by adding a notion of cache to the code (returning the previous keys and values, which you can feed back in later). GPT-2 implements this cache with the variable names past and present.
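A minimal sketch of how that cache is used with GPT-2 (assuming a recent transformers version, where the cache is exposed as past_key_values; greedy decoding just for illustration):

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")
model.eval()

input_ids = tokenizer("Hello, my name is", return_tensors="pt").input_ids
generated = input_ids
past_key_values = None

with torch.no_grad():
    for _ in range(20):
        outputs = model(input_ids=input_ids, past_key_values=past_key_values, use_cache=True)
        past_key_values = outputs.past_key_values  # cached keys/values of all previous tokens
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=1)
        input_ids = next_token  # with the cache, only the newest token has to be fed in

print(tokenizer.decode(generated[0]))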

@claartje-barkhof, thank you for your response. Can you provide an example for the bert2bert model?
I build the model like this

  model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased')

And when I run the forward method like this

outputs = model(input_ids=src, decoder_input_ids=dst, labels=dst, return_dict=True)

It outputs all the logits at once.
I don’t know how to control it step by step.

If you want to generate incrementally, it should be something along the lines of the following. Read this as pseudocode :slight_smile:

import torch
import torch.nn.functional as F
from transformers import top_k_top_p_filtering

generated_so_far = torch.tensor([[start_token]])  # start with just the decoder start token
while eos_token not in generated_so_far[0]:
    # incrementally build up decoder_input_ids with your previous predictions, so each
    # prediction depends on both the encoder inputs and the previously generated tokens
    outputs = model(..., decoder_input_ids=generated_so_far, ...)
    next_token_logits = outputs[0][0, -1, :]  # this indexing depends on the model, but take the logits of the last token
    filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
    next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
    generated_so_far = torch.cat((generated_so_far, next_token.unsqueeze(0)), dim=1)

Thank you very much!
I want to ask another question.
Do you know how much the speed will be reduced with this step-by-step method if the ‘cache’ mechanism is not implemented?
It seems that at every single generation step the whole source text is passed through the encoder model, even though it only needs to be encoded once.
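If I understand the EncoderDecoderModel forward signature correctly, it accepts a precomputed encoder_outputs, so I guess I could encode the source once and reuse it at every step, something like this (rough sketch; src, src_attention_mask, start_token and eos_token are placeholders):

import torch

# encode the source text once
encoder_outputs = model.encoder(input_ids=src, attention_mask=src_attention_mask, return_dict=True)

generated_so_far = torch.tensor([[start_token]])
while eos_token not in generated_so_far[0]:
    # the encoder is skipped when encoder_outputs is passed in
    outputs = model(
        encoder_outputs=encoder_outputs,
        attention_mask=src_attention_mask,
        decoder_input_ids=generated_so_far,
        return_dict=True,
    )
    next_token = outputs.logits[0, -1, :].argmax().reshape(1, 1)
    generated_so_far = torch.cat((generated_so_far, next_token), dim=1)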

That very much depends on the size of the model, batch size, the sequence length you want/expect to generate and on your available resources.

I was trying to run without the cache with GPT-2 the other day and ran into issues quite soon. Longer sequences took way too long: I think already with batch size 8, sequence length 30, and the GPT-2 base model it started to get really slow on my local machine. It also depends on whether you want to decode a lot of sequences at once, etc. Just for trying a few examples, you might be okay without the cache.

I would recommend reading this blog about the cache to understand the concept, and then checking the code of GPT-2 to see how it is implemented.

Thank you!

@claartje-barkhof @guoziyuan Hi guys, almost two years later but this thread came in very handy!

Could you expand a bit more on how to perform the caching? I read the Illustrated GPT-2 blog, but it would be great to have a more concrete example of how to achieve it. For reference, I’m using a BART model for summarization.
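In the meantime, here is my attempt at a manual decoding loop with the cache for BART, based on the discussion above (just a rough sketch, not verified; I realize model.generate() already does this internally, but I need to intervene at each step):

import torch
from transformers import BartForConditionalGeneration, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
model.eval()

inputs = tokenizer("Some long article to summarize ...", return_tensors="pt", truncation=True)
# encode the article once and reuse the encoder outputs at every step
encoder_outputs = model.get_encoder()(**inputs, return_dict=True)

decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
generated = decoder_input_ids
past_key_values = None

with torch.no_grad():
    for _ in range(60):
        outputs = model(
            encoder_outputs=encoder_outputs,
            attention_mask=inputs.attention_mask,
            decoder_input_ids=decoder_input_ids,
            past_key_values=past_key_values,
            use_cache=True,
            return_dict=True,
        )
        past_key_values = outputs.past_key_values  # cache of all previous decoder steps
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=1)
        decoder_input_ids = next_token  # only feed the newest token once the cache exists
        if next_token.item() == model.config.eos_token_id:
            break

print(tokenizer.decode(generated[0], skip_special_tokens=True))

Does this look right to you?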