Correct input_ids when passing past_key_values

The documentation (e.g. for Mistral's forward method) states that:

  • If past_key_values are used, the user can optionally input only the last input_ids (those that don’t have their past key value states given to this model) of shape (batch_size, 1) instead of all input_ids of shape (batch_size, sequence_length).

Is it really 'optional' to pass only the input_ids that are not yet in the cache? As far as I can tell from scanning the code, there is no obvious mechanism for handling input_ids that are already in the cache. So is the documentation correct, or is it mandatory to pass only the input_ids not present in the cache?

Hey! If you're calling the forward method directly, you must pass only the new input tokens, i.e. the ones that come after the cached tokens (during step-by-step decoding that is just the last token). The codebase handles this inside the attention blocks, where the cached keys/values are concatenated with the new ones to form the full-length sequence. Note that you still have to pass the full attention mask, covering both cached and new tokens, not just the entry for the last token.
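To make the mechanism concrete, here is a toy, pure-Python sketch of a single attention layer with a key/value cache. It is not the transformers implementation: the `forward` and `attend` names, the identity query/key/value projections and the list-based cache are all simplifying assumptions. It shows that feeding only the uncached token together with a full-length mask reproduces the full-sequence result.

```python
import math
from typing import List, Optional, Tuple

Vector = List[float]
Cache = Tuple[List[Vector], List[Vector]]  # (cached keys, cached values)


def attend(q: Vector, keys: List[Vector], values: List[Vector], mask: List[int]) -> Vector:
    """Scaled dot-product attention of one query over the given keys (mask 0 = ignore)."""
    scale = math.sqrt(len(q))
    scores = [
        sum(qi * ki for qi, ki in zip(q, k)) / scale if m else float("-inf")
        for k, m in zip(keys, mask)
    ]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]  # masked positions get weight 0
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(len(values[0]))]


def forward(
    new_tokens: List[Vector], attention_mask: List[int], past: Optional[Cache] = None
) -> Tuple[List[Vector], Cache]:
    """Toy causal attention layer with a KV cache (hypothetical, not transformers code).

    new_tokens: only the tokens that are NOT in the cache yet.
    attention_mask: covers the whole sequence, cached + new tokens.
    """
    past_keys, past_values = past if past is not None else ([], [])
    # Identity projections (assumption): key = value = query = token vector.
    keys = past_keys + list(new_tokens)  # cached keys concatenated with the new ones
    values = past_values + list(new_tokens)
    if len(attention_mask) != len(keys):
        raise ValueError("attention_mask must span cached + new tokens")
    outputs = []
    for i, q in enumerate(new_tokens):
        pos = len(past_keys) + i  # absolute position of this query
        # Causal attention: the query at `pos` sees keys 0..pos only.
        outputs.append(attend(q, keys[: pos + 1], values[: pos + 1], attention_mask[: pos + 1]))
    return outputs, (keys, values)


# Full sequence in one call vs. prefill + one cached decoding step.
seq = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
mask = [1, 1, 1]
full_out, _ = forward(seq, mask)
_, cache = forward(seq[:2], mask[:2])             # prefill the first two tokens
step_out, _ = forward(seq[2:], mask, past=cache)  # only the new token, full mask
print(full_out[-1], step_out[0])  # same vector both times
```

Because the cache is simply concatenated with whatever is passed in, nothing removes tokens that are already cached: calling `forward(seq, mask, past=cache)` here yields five keys for a three-entry mask, which this toy reports as a `ValueError`.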

If you're calling the generate() method, pass the whole sequence of input_ids and the full attention mask, since generate() needs them to infer the actual length of the input.
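Going the other way, generate() is given the full sequence and works out the uncached part itself. The toy loop below sketches that bookkeeping in pure Python. `prepare_step_inputs` and `toy_generate` are hypothetical names, the "model" just emits the previous token plus one, and the slicing only loosely mirrors what a model's `prepare_inputs_for_generation` does inside transformers (the real logic differs across versions).

```python
from typing import List, Tuple


def prepare_step_inputs(
    input_ids: List[int], attention_mask: List[int], past_length: int
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: drop the tokens the cache already holds.

    The attention mask is returned in full, because forward() still needs it.
    """
    if len(input_ids) != len(attention_mask):
        raise ValueError("input_ids and attention_mask must have the same length")
    if not 0 <= past_length <= len(input_ids):
        raise ValueError("past_length is out of range")
    return input_ids[past_length:], attention_mask


def toy_generate(
    input_ids: List[int], attention_mask: List[int], max_new_tokens: int
) -> Tuple[List[int], List[List[int]]]:
    """Toy greedy loop: the caller passes the FULL prompt and mask, as with generate()."""
    ids, mask = list(input_ids), list(attention_mask)
    past_length = 0  # nothing is cached before the first step
    fed = []  # what each simulated forward() call received
    for _ in range(max_new_tokens):
        step_ids, _step_mask = prepare_step_inputs(ids, mask, past_length)
        fed.append(step_ids)
        past_length = len(ids)  # everything fed so far is now in the cache
        next_token = step_ids[-1] + 1  # stand-in for argmax over real model logits
        ids.append(next_token)
        mask.append(1)  # the generated token is always attended to
    return ids, fed


ids, fed = toy_generate([5, 6, 7], [1, 1, 1], max_new_tokens=3)
print(ids)  # [5, 6, 7, 8, 9, 10]
print(fed)  # [[5, 6, 7], [8], [9]]
```

The first simulated forward call receives the whole prompt (the prefill); every later call receives only the newest token, while the mask keeps growing to the full sequence length.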
