How to get `decoder_input_ids` for Seq2SeqLM in order to obtain `past_key_values`

Sorry for the daily questions.

I’m working on Seq2SeqLM, and thanks to the advice I’ve gotten on this forum, I’m making progress with both understanding and implementation. Thank you so much!

Now I would like to get `past` (i.e., `past_key_values`) from Seq2SeqLM, just as I can from CausalLM.

> `past_key_values` (`List[torch.FloatTensor]`, optional, returned when `use_cache=True` is passed or when `config.use_cache=True`): List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be used (see `past_key_values` input) to speed up sequential decoding.
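
To make the quoted layout concrete, here is a minimal pure-Python sketch that only computes the shapes the docstring describes. It does not call `transformers`, and both the helper name `past_key_values_shapes` and the example sizes are hypothetical.

```python
from typing import List, Tuple

Shape = Tuple[int, int, int, int, int]


def past_key_values_shapes(
    n_layers: int,
    batch_size: int,
    num_heads: int,
    sequence_length: int,
    embed_size_per_head: int,
) -> List[Shape]:
    """Hypothetical helper: shapes of the cache tensors per the quoted docstring.

    One entry per decoder layer; the leading 2 stacks the cached keys and values.
    """
    return [
        (2, batch_size, num_heads, sequence_length, embed_size_per_head)
        for _ in range(n_layers)
    ]


# Made-up sizes: 12 layers, batch of 1, 16 heads of size 64, 5 tokens decoded so far.
shapes = past_key_values_shapes(12, 1, 16, 5, 64)
print(len(shapes), shapes[0])  # 12 (2, 1, 16, 5, 64)
```

Only the `sequence_length` axis changes during decoding: each processed token adds one position per layer, which is what lets later steps skip recomputing keys and values for earlier tokens.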

At first, I couldn't get `past_key_values` even though I had confirmed that `use_cache=True`. Then I read the `forward` implementation of the `BartModel` class and realized that I need to pass `decoder_input_ids` as an input, because of this check:

        if decoder_input_ids is None:
            use_cache = False
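
The effect of that guard can be mimicked with a toy function. This is a hypothetical mock of the control flow only, not BART's real `forward`: a plain list stands in for `past_key_values`, and `input_ids` is accepted but unused.

```python
from typing import List, Optional


def toy_forward(
    input_ids: List[int],
    decoder_input_ids: Optional[List[int]] = None,
    use_cache: bool = True,
) -> Optional[List[int]]:
    """Hypothetical mock: return a stand-in for past_key_values, or None."""
    if decoder_input_ids is None:
        # Same guard as the quoted BartModel code: no decoder inputs, no cache.
        use_cache = False
    if not use_cache:
        return None
    # Stand-in "cache": just remember which decoder positions were processed.
    return list(decoder_input_ids)


print(toy_forward([0, 5, 6, 2]))                         # None
print(toy_forward([0, 5, 6, 2], decoder_input_ids=[2]))  # [2]
```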

So now I'm wondering: how do I obtain the `decoder_input_ids` for Seq2SeqLM?

My understanding is this: CausalLM obtains `past_key_values` from the given context (i.e., the prompt) and then from each token it generates, in turn.
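
That interpretation can be sketched as a toy greedy loop. Everything here is hypothetical: `toy_causal_step` stands in for a model call, a plain list of token ids stands in for the per-layer key/value cache, and the "prediction" rule is arbitrary so the output is deterministic. The point is the calling pattern: the prompt is fed once, then only the newest token is fed while the cache grows.

```python
from typing import List, Tuple

# Toy stand-in for the key/value cache: one entry per token already processed.
Cache = List[int]


def toy_causal_step(new_tokens: List[int], past: Cache) -> Tuple[int, Cache]:
    """Hypothetical model call: processes only `new_tokens`, reusing `past`."""
    past = past + new_tokens      # cache grows by len(new_tokens)
    next_token = sum(past) % 50   # arbitrary deterministic "prediction"
    return next_token, past


def greedy_generate(prompt: List[int], max_new_tokens: int) -> Tuple[List[int], Cache]:
    generated: List[int] = []
    past: Cache = []
    step_input = prompt           # first call: the whole prompt (context)
    for _ in range(max_new_tokens):
        next_token, past = toy_causal_step(step_input, past)
        generated.append(next_token)
        step_input = [next_token]  # later calls: only the newest token
    return generated, past


tokens, cache = greedy_generate([3, 4, 5], max_new_tokens=3)
print(tokens)      # [12, 24, 48]
print(len(cache))  # 5
```

The cache ends one token behind the output: the last generated token has been predicted but not yet fed back in.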

So how does Seq2SeqLM get its `decoder_input_ids` when, after processing the given context with the encoder, it generates tokens one at a time?
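
For comparison, here is a toy sketch of the encoder-decoder version of the same loop. It assumes (this should be checked against the `generate()` source of the installed `transformers` version) that `decoder_input_ids` are seeded with `config.decoder_start_token_id` and then grow by one generated token per step, with only the newest token passed in once a cache exists. The encoder and decoder functions, the token ids, and the prediction rule are all hypothetical stand-ins.

```python
from typing import List, Tuple

# Assumed ids: BART-style configs use the EOS id (2) as decoder_start_token_id.
DECODER_START_TOKEN_ID = 2
EOS_TOKEN_ID = 2


def toy_encode(input_ids: List[int]) -> int:
    """Hypothetical encoder: runs once and summarises the context as one number."""
    return sum(input_ids)


def toy_decoder_step(
    new_token: int, past: List[int], encoder_state: int
) -> Tuple[int, List[int]]:
    """Hypothetical decoder call: one new token plus the cache and encoder state."""
    past = past + [new_token]                      # self-attention cache grows by one
    next_token = (encoder_state + sum(past)) % 10  # arbitrary deterministic "prediction"
    return next_token, past


def seq2seq_greedy(input_ids: List[int], max_new_tokens: int) -> Tuple[List[int], List[int]]:
    encoder_state = toy_encode(input_ids)          # the encoder runs exactly once
    decoder_input_ids = [DECODER_START_TOKEN_ID]   # seeded with the start token
    past: List[int] = []
    for _ in range(max_new_tokens):
        # With a cache, only the newest decoder token is passed in.
        next_token, past = toy_decoder_step(decoder_input_ids[-1], past, encoder_state)
        decoder_input_ids.append(next_token)       # the output becomes the next input
        if next_token == EOS_TOKEN_ID:
            break
    return decoder_input_ids, past


decoder_ids, cache = seq2seq_greedy([0, 31, 42, 2], max_new_tokens=4)
print(decoder_ids)  # [2, 7, 4, 8, 6]
```

When calling a real model by hand, the analogous pattern would be to pass the seeded `decoder_input_ids` with `use_cache=True`, then feed the returned `past_key_values` back on the next call together with only the last decoder token. Argument names have changed across `transformers` versions, so it is worth checking the `forward` signature of the installed version.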

Thank you in advance.

yusukemori