Control EncoderDecoderModel to generate tokens step by step

If you want to generate incrementally, it should be something along the lines of the following. Read this as pseudocode :slight_smile:

```python
import torch
import torch.nn.functional as F
from transformers import top_k_top_p_filtering  # removed from recent transformers versions

# start_token_id, eos_token_id, top_k and top_p are assumed to be defined
generated_so_far = torch.tensor([[start_token_id]])  # shape (1, 1), not a plain list
while eos_token_id not in generated_so_far[0]:
    # incrementally build up decoder_input_ids with your previous predictions; this way
    # each prediction depends on both the encoder inputs and the previous outputs
    outputs = model(..., decoder_input_ids=generated_so_far)  # fill in your encoder inputs
    next_token_logits = outputs[0][0, -1, :]  # indexing depends on the model; take the logits of the last position
    filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
    next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
    generated_so_far = torch.cat((generated_so_far, next_token.unsqueeze(0)), dim=1)
```
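
To show what goes into the `...` above, here is a self-contained sketch of the same loop with a BERT-to-BERT `EncoderDecoderModel`. The checkpoint names, the plain multinomial sampling (instead of `top_k_top_p_filtering`, which newer transformers versions no longer export), and the length cap are my own illustrative assumptions, not part of the original answer:

```python
import torch
import torch.nn.functional as F
from transformers import BertTokenizer, EncoderDecoderModel

# Hypothetical setup: a BERT-to-BERT encoder-decoder. The cross-attention is
# randomly initialized here, so sensible outputs require fine-tuning first.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = EncoderDecoderModel.from_encoder_decoder_pretrained(
    "bert-base-uncased", "bert-base-uncased"
)
model.eval()

input_ids = tokenizer("Hello, how are you?", return_tensors="pt").input_ids
decoder_input_ids = torch.tensor([[tokenizer.cls_token_id]])  # decoder start token

with torch.no_grad():
    for _ in range(20):  # cap the length rather than relying on EOS alone
        outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
        next_token_logits = outputs.logits[0, -1, :]  # logits at the last position
        probs = F.softmax(next_token_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)  # sample one token
        decoder_input_ids = torch.cat((decoder_input_ids, next_token.unsqueeze(0)), dim=1)
        if next_token.item() == tokenizer.sep_token_id:  # BERT's [SEP] acts as EOS
            break

print(tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True))
```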