I am currently wondering how to use the cache/get a valid cache position.
First of all, my use case: I have an encoder-decoder structure where the decoder has a Mamba2 backbone.
The general idea is that I want to encode my input, feed the encoding into the decoder, and generate a new sequence autoregressively.
As far as I understand, I can now do something like this (please correct me if I am wrong):
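Roughly this pattern (just a rough sketch of what I mean; encoder, decoder, and max_sequence_length come from my own setup):

def forward(input):
    # encode the input once
    encoder_embeddings = encoder(input)

    # first decoder pass over the encoder embeddings
    first_pred = decoder(
        inputs_embeds=encoder_embeddings,
        use_cache=True,
    )
    sequence = [first_pred.last_hidden_state]
    cache_params = first_pred.cache_params

    # autoregressive decoding, reusing the cache
    while len(sequence) < max_sequence_length:
        next_pred = decoder(
            inputs_embeds=sequence[-1],
            cache_params=cache_params,
            use_cache=True,
            cache_position=...,  # <-- this is the part I am unsure about
        )
        sequence.append(next_pred.last_hidden_state)
        cache_params = next_pred.cache_params
    return sequence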
What I am wondering right now is how to get the cache_position. And did I understand correctly that the cache is the internal state of the Mamba2 model?
Any help would be appreciated.
Many thanks in advance
Your understanding of the caching mechanism in the autoregressive generation process is on the right track! Let me clarify a few things and address your questions.
Understanding Cache in Decoder Models
Yes, the cache is exactly that: the decoder's internal state for the sequence processed so far. In attention-based decoders this would be the past key/value pairs; for a Mamba2 backbone it is the Mamba2Cache, which holds each layer's rolling convolution state and SSM recurrent state. Either way it lets the model avoid recomputing work for past tokens, which significantly improves efficiency during autoregressive decoding.
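If you want to see this concretely, here is a minimal, hedged sketch that inspects the returned cache. It assumes the Hugging Face transformers Mamba2 implementation; the tiny config values are only illustrative, and the conv_states / ssm_states attribute names come from its Mamba2Cache class.

import torch
from transformers import Mamba2Config, Mamba2Model

# Tiny illustrative config; with expand=2, num_heads * head_dim must match
# expand * hidden_size (here 8 * 8 == 2 * 32).
config = Mamba2Config(
    hidden_size=32,
    num_heads=8,
    head_dim=8,
    state_size=16,
    n_groups=1,
    num_hidden_layers=2,
    conv_kernel=4,
    chunk_size=8,
)
model = Mamba2Model(config)

inputs_embeds = torch.randn(1, 5, config.hidden_size)
out = model(inputs_embeds=inputs_embeds, use_cache=True)

cache = out.cache_params          # a Mamba2Cache instance, not attention key/values
print(cache.conv_states.shape)    # per-layer rolling 1D-convolution buffers
print(cache.ssm_states.shape)     # per-layer recurrent SSM states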
Addressing Your Code
Your example code is conceptually correct. You’re:
Generating encoder embeddings from the input.
Using those embeddings to initialize the decoder and generate the first prediction.
Iteratively decoding the sequence autoregressively while utilizing the cache to speed up computations.
Here’s the critical part of your question about cache_position:
cache_position
The cache_position represents the position of the current token relative to the cached context. It is typically managed for you when you go through model.generate; if you drive the decoder's forward pass yourself, you may need to manage it explicitly.
In practice:
During the first forward pass, cache_position is initialized (usually starting at 0).
For subsequent passes, the position increments by 1 for each token generated.
You can calculate it manually as the length of the current sequence if the decoder doesn’t handle this internally:
cache_position = len(sequence)
Suggested Code Update
Here’s how your code might look with adjustments for clarity:
def forward(input):
    encoder_embeddings = encoder(input)

    # Initial prediction (prefill): no cache is passed yet
    first_pred = decoder(
        inputs_embeds=encoder_embeddings,
        cache_params=None,  # no cache for the first prediction
        use_cache=True,
    )
    sequence = [first_pred.last_hidden_state]
    cache_params = first_pred.cache_params
    cache_position = len(sequence)  # start cache position at 1

    # Autoregressive decoding
    while len(sequence) < max_sequence_length:
        next_pred = decoder(
            inputs_embeds=sequence[-1][:, -1:, :],  # feed only the most recent hidden state back in
            cache_params=cache_params,
            use_cache=True,
            cache_position=cache_position,
        )
        sequence.append(next_pred.last_hidden_state)
        cache_params = next_pred.cache_params
        cache_position += 1  # increment cache position

    return sequence
Key Takeaways
The cache stores the decoder's running state for past tokens (for Mamba2, per-layer convolution and SSM states rather than attention key-value pairs), enabling efficient autoregressive generation.
cache_position can often be inferred from the sequence length but is sometimes managed internally by the model.
Make sure the decoder model you are using (with the Mamba2 backbone) supports explicit cache and cache position handling. Check the library’s documentation or implementation details for confirmation.
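One quick way to do that check from Python, rather than reading the source, is to inspect the forward signature (a small sketch assuming the transformers Mamba2Model class):

import inspect
from transformers import Mamba2Model

# Check which cache-related arguments the decoder's forward pass accepts.
params = inspect.signature(Mamba2Model.forward).parameters
print("use_cache" in params)       # expected: True
print("cache_params" in params)    # expected: True
print("cache_position" in params)  # expected: True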
First of all, thank you for the quick reply. I do have a follow-up:
I simply use the class Mamba2Model(Mamba2PreTrainedModel) as a decoder, and as far as I can see there is no automatic cache_position management.
In the forward method, we have:
if use_cache:
    if cache_params is None:
        cache_params = Mamba2Cache(
            self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
        )
        cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
    elif cache_position is None:
        # cases when we do manual forward instead of using `model.generate` which will initiate
        # `cache_position` and makes sure it is not None, throw error here instead of doing some
        # hack to conjecture the current cache position
        raise ValueError(
            "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, "
            "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will "
            "be initialized for you automatically"
        )
else:
    cache_params = None
But the cache_position is never returned, only the cache_params. Am I missing something here?
But the class Mamba2Model(Mamba2PreTrainedModel) initializes the cache position as:
cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
And the cache position is required to be Optional[torch.LongTensor], so I am wondering why cache_position = len(sequence) should be sufficient.
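For reference, this is roughly what I would guess I have to do instead (just a sketch based on the source above, I am not sure it is correct; decoder and encoder_embeddings come from my own setup):

import torch

# prefill: let the model create cache_params (and its internal cache_position) itself
first_pred = decoder(inputs_embeds=encoder_embeddings, use_cache=True)
cache_params = first_pred.cache_params

# decoding step: pass an explicit LongTensor instead of a plain int
step = encoder_embeddings.shape[1]  # number of positions already processed
cache_position = torch.tensor([step], dtype=torch.long, device=encoder_embeddings.device)

next_pred = decoder(
    inputs_embeds=first_pred.last_hidden_state[:, -1:, :],  # only the newest hidden state
    cache_params=cache_params,
    use_cache=True,
    cache_position=cache_position,
)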