Sorry for the daily questions. I'm working on `Seq2SeqLM`, and thanks to the advice I've gotten on this forum I'm able to move forward with understanding and implementation. Thank you so much!
Now I would also like to get `past_key_values` from the `Seq2SeqLM` output. The docs describe it as:

> past_key_values (List[torch.FloatTensor], optional, returned when use_cache=True is passed or when config.use_cache=True) – List of torch.FloatTensor of length config.n_layers, with each tensor of shape (2, batch_size, num_heads, sequence_length, embed_size_per_head). Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be used (see past_key_values input) to speed up sequential decoding.
At first I was having trouble getting `past_key_values` even though I had confirmed `use_cache = True`, but then I read `BartModel`'s implementation of `forward` and realized I need to pass `decoder_input_ids`:

```python
if decoder_input_ids is None:
    use_cache = False
```
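To double-check this, I ran a minimal forward pass with `decoder_input_ids` supplied and `use_cache=True`. (Sketch below uses a tiny randomly initialized `BartConfig` so it runs without downloading weights; a real checkpoint such as `facebook/bart-base` should behave the same way.)

```python
import torch
from transformers import BartConfig, BartForConditionalGeneration

# Tiny randomly initialized BART, just to inspect the cache shapes
# (a real checkpoint like facebook/bart-base works the same way).
config = BartConfig(
    vocab_size=50, d_model=16,
    encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
)
model = BartForConditionalGeneration(config).eval()

input_ids = torch.tensor([[0, 5, 6, 7, 2]])  # the "context" fed to the encoder
decoder_input_ids = torch.tensor([[config.decoder_start_token_id]])

with torch.no_grad():
    out = model(input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                use_cache=True)

# One entry per decoder layer; each entry holds the cached key/value
# tensors of shape (batch_size, num_heads, sequence_length, head_dim).
print(len(out.past_key_values))         # one entry per decoder layer
print(out.past_key_values[0][0].shape)  # self-attention keys for layer 0
```

With `decoder_input_ids` omitted, `out.past_key_values` comes back as `None`, which matches the `use_cache = False` branch above.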
Now I am wondering how to get `past_key_values` during generation. I interpret it as follows: they are built up from the given context (or prompt) and then, in turn, from the output generated from it.
So how does `Seq2SeqLM` get the `decoder_input_ids` when it sequentially generates tokens, after it has processed the given context with the encoder?
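For reference, here is my current understanding of that loop as a sketch: the decoder does not start from the prompt but from `decoder_start_token_id`, and at each step only the newest token is fed in together with the cached `past_key_values`, while the chosen token is appended to `decoder_input_ids` for the next step. (Again a tiny random `BartConfig` just so the snippet runs; the token ids are made up.)

```python
import torch
from transformers import BartConfig, BartForConditionalGeneration

config = BartConfig(
    vocab_size=50, d_model=16,
    encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=2, decoder_attention_heads=2,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
)
model = BartForConditionalGeneration(config).eval()

input_ids = torch.tensor([[0, 5, 6, 7, 2]])  # encoder context ("prompt")

# Encode the context once; the decoder reuses it at every step.
encoder_outputs = model.get_encoder()(input_ids=input_ids)

# The decoder starts from decoder_start_token_id, not from the prompt.
generated = torch.tensor([[config.decoder_start_token_id]])
past = None

with torch.no_grad():
    for _ in range(5):  # greedy decoding, fixed number of steps
        # With a cache, only the most recent token needs to be passed.
        step_ids = generated if past is None else generated[:, -1:]
        out = model(encoder_outputs=encoder_outputs,
                    decoder_input_ids=step_ids,
                    past_key_values=past,
                    use_cache=True)
        past = out.past_key_values  # cache grows by one step each iteration
        next_token = out.logits[:, -1, :].argmax(-1, keepdim=True)
        generated = torch.cat([generated, next_token], dim=-1)

print(generated.shape)  # start token + 5 generated tokens
```

Is this roughly what `generate()` does internally, i.e. is the model itself responsible for feeding its own output back in as `decoder_input_ids`?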
Thank you in advance.