The documentation (e.g. for Mistral's forward method) states that:
- If past_key_values are used, the user can optionally input only the last input_ids (those that don't have their past key value states given to this model) of shape (batch_size, 1) instead of all input_ids of shape (batch_size, sequence_length).
I wonder whether it is really "optional" to input only the input_ids that are not present in the cache. As far as I can tell from scanning the code, there isn't any obvious mechanism to handle the case where input_ids that are already in the cache are passed again. So is the documentation correct, or is it obligatory to pass only the input_ids not present in the cache?
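For comparison, the safe pattern (passing only the uncached tokens) amounts to slicing off the prefix whose key/value states the cache already holds. The sketch below is a minimal illustration, assuming plain Python lists stand in for the (batch_size, sequence_length) tensor and that `past_length` is the number of positions already cached; `trim_to_uncached` is a hypothetical helper, not part of the transformers API.

```python
from typing import List


def trim_to_uncached(input_ids: List[List[int]], past_length: int) -> List[List[int]]:
    """Drop the first `past_length` tokens of every row.

    Hypothetical helper for illustration only: `input_ids` is a batch of
    token-id rows of shape (batch_size, sequence_length), and the cache is
    assumed to already hold key/value states for the first `past_length`
    positions of every row.
    """
    if past_length < 0:
        raise ValueError("past_length must be non-negative")
    # Keep only the positions whose key/value states are NOT in the cache.
    trimmed = [row[past_length:] for row in input_ids]
    # Each row must still contain at least one new token to feed the model.
    if any(len(row) == 0 for row in trimmed):
        raise ValueError("all tokens are already cached; nothing new to feed")
    return trimmed


# Batch of 2 sequences; the cache already covers the first 3 positions.
full_ids = [[101, 7, 8, 9], [101, 4, 5, 6]]
new_ids = trim_to_uncached(full_ids, past_length=3)
print(new_ids)  # [[9], [6]], i.e. shape (batch_size, 1)
```

If the model's forward pass does not perform this slicing itself, passing the full sequence alongside past_key_values would feed the cached positions in a second time, which is why the question above matters.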