Sorry for the daily questions. I'm working on `Seq2SeqLM`, and thanks to the advice I've gotten on this forum, I'm able to move forward with understanding and implementation. Thank you so much!

Now, I would like to get `past` (or, `past_key_values`) in `Seq2SeqLM` as well as in `CausalLM`.
> `past_key_values` (`List[torch.FloatTensor]`, optional, returned when `use_cache=True` is passed or when `config.use_cache=True`) – List of `torch.FloatTensor` of length `config.n_layers`, with each tensor of shape `(2, batch_size, num_heads, sequence_length, embed_size_per_head)`. Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be used (see `past_key_values` input) to speed up sequential decoding.
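For `CausalLM`, this already works for me. A minimal sketch of what I mean (the `gpt2` checkpoint is just an example I picked):

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.eval()

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, use_cache=True, return_dict=True)

# One cache entry per layer, holding the pre-computed keys and values.
print(len(outputs.past_key_values))  # 12 for gpt2
```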
At first, I was having trouble getting `past_key_values` even though I had confirmed `use_cache=True`. But then I read the class `BartModel`'s implementation of `forward` and realized I need to pass `decoder_input_ids` in the inputs:

```python
if decoder_input_ids is None:
    use_cache = False
```
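For reference, here is a minimal sketch of what works for me now (the `facebook/bart-base` checkpoint is just an example): once `decoder_input_ids` is passed explicitly, `past_key_values` is returned as expected.

```python
import torch
from transformers import BartTokenizer, BartModel

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartModel.from_pretrained("facebook/bart-base")
model.eval()

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
# Providing decoder_input_ids explicitly keeps use_cache from being
# switched off inside BartModel.forward.
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

with torch.no_grad():
    outputs = model(
        input_ids=inputs["input_ids"],
        decoder_input_ids=decoder_input_ids,
        use_cache=True,
        return_dict=True,
    )

print(outputs.past_key_values is not None)  # True
```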
Now I am wondering how to get the `decoder_input_ids` in `Seq2SeqLM`.
My interpretation is as follows: `CausalLM` obtains `past_key_values` using the given context (or, prompt) and, in turn, the output generated from it. So how does `Seq2SeqLM` get the `decoder_input_ids` when it sequentially generates tokens from the given context after processing it with the encoder?
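To make the question concrete, here is a minimal greedy-decoding sketch of what I imagine happens (my assumption: the decoder is seeded with `config.decoder_start_token_id`; `facebook/bart-base` and the 10-step loop are just for illustration, so the decoded text itself won't be meaningful):

```python
import torch
from transformers import BartTokenizer, BartForConditionalGeneration

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
model.eval()

enc = tokenizer("UN chief says there is no military solution", return_tensors="pt")

# Run the encoder once; the decoder reuses its output at every step.
encoder_outputs = model.get_encoder()(enc["input_ids"], return_dict=True)

# The decoder has no prompt: it starts from decoder_start_token_id.
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
past_key_values = None

with torch.no_grad():
    for _ in range(10):
        outputs = model(
            encoder_outputs=encoder_outputs,
            # With a cache, only the newest decoder token is fed in.
            decoder_input_ids=decoder_input_ids[:, -1:],
            past_key_values=past_key_values,
            use_cache=True,
            return_dict=True,
        )
        past_key_values = outputs.past_key_values
        next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)

print(tokenizer.decode(decoder_input_ids[0]))
```

Is this roughly how `Seq2SeqLM` obtains `decoder_input_ids` step by step during generation, or am I missing something?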
Thank you in advance.
yusukemori