Mamba2 Cache Position

Hello everyone

I am currently wondering how to use the cache and how to get a valid cache position.
First of all, my use case: I have an encoder-decoder structure where the decoder has a Mamba2 backbone.
The general idea is that I encode my input, feed the encoder output into the decoder, and create a new sequence autoregressively.
As far as I understand, I can now do something like this (please correct me if I am wrong):

def forward(input):
    encoder_embeddings = encoder(input)
    first_pred = decoder(inputs_embeds=encoder_embeddings, cache_params=None, use_cache=True)
    sequence = [first_pred.last_hidden_state]
    cache_params = first_pred.cache_params
    ## autoregressive part
    while len(sequence) < max_sequence_length:
        next_sequence_element, cache_params, _ = decoder(
            inputs_embeds=sequence[-1], cache_params=cache_params, use_cache=True, cache_position=???
        )
        sequence.append(next_sequence_element)
        

What I am wondering right now is how to get the cache_position. And did I understand correctly that the cache is the internal state of the Mamba2 model?
Any help would be appreciated.
Many thanks in advance :slight_smile:


Hi @Mesuma,

Your understanding of the caching mechanism in the autoregressive generation process is on the right track! Let me clarify a few things and address your questions.

Understanding Cache in Decoder Models

Yes, the cache in autoregressive decoder models stores the model’s internal state for previously processed sequence elements. In attention-based transformers this is the key-value pairs of the attention mechanism; for the Mamba2 backbone you’re using, the Mamba2Cache instead holds the per-layer convolution states and SSM (recurrent) states. Either way, it lets the model avoid recomputing over past tokens, significantly improving efficiency during autoregressive decoding.
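
If you want to see this concretely, here is a minimal sketch. It assumes the transformers implementation, where Mamba2Cache is defined in transformers.models.mamba2.modeling_mamba2 and exposes conv_states and ssm_states attributes; the tiny config values below are arbitrary and only for illustration.

import torch
from transformers import Mamba2Config
from transformers.models.mamba2.modeling_mamba2 import Mamba2Cache

# Tiny config purely for illustration (values chosen so that
# expand * hidden_size == num_heads * head_dim).
config = Mamba2Config(hidden_size=64, num_hidden_layers=2, num_heads=8, head_dim=16)

# The same constructor call that Mamba2Model's forward uses when prefilling.
cache = Mamba2Cache(config, 1, device="cpu", dtype=torch.float32)

# The cache holds per-layer convolution states and SSM states,
# not attention key-value pairs.
print(hasattr(cache, "conv_states"), hasattr(cache, "ssm_states"))  # True True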

Addressing Your Code

Your example code is conceptually correct. You’re:

  1. Generating encoder embeddings from the input.
  2. Using those embeddings to initialize the decoder and generate the first prediction.
  3. Iteratively decoding the sequence autoregressively while utilizing the cache to speed up computations.

Here’s the critical part of your question about cache_position:

cache_position

The cache_position is typically managed for you when you decode through model.generate. If you call the forward pass manually, you have to manage it explicitly: cache_position represents the position of the current token (or tokens) relative to the cached context.

In practice:

  • During the first forward pass, cache_position is initialized (usually starting at 0).
  • For subsequent passes, the position increments by 1 for each token generated.

You can calculate it manually as the length of the current sequence if the decoder doesn’t handle this internally:

cache_position = len(sequence)

Suggested Code Update

Here’s how your code might look with adjustments for clarity:

def forward(input):
    encoder_embeddings = encoder(input)
    
    # Initial prediction
    first_pred = decoder(
        inputs_embeds=encoder_embeddings,
        cache_params=None,  # No cache for the first prediction
        use_cache=True
    )
    sequence = [first_pred.last_hidden_state]
    cache_params = first_pred.cache_params
    cache_position = len(sequence)  # Start cache position at 1
    
    # Autoregressive decoding
    while len(sequence) < max_sequence_length:
        next_pred = decoder(
            inputs_embeds=sequence[-1],
            cache_params=cache_params,
            use_cache=True,
            cache_position=cache_position
        )
        sequence.append(next_pred.last_hidden_state)
        cache_params = next_pred.cache_params
        cache_position += 1  # Increment cache position

    return sequence

Key Takeaways

  1. The cache stores the model’s past internal states (for Mamba2, the per-layer convolution and SSM states), enabling efficient autoregressive generation.
  2. cache_position can often be inferred from the sequence length but is sometimes managed internally by the model.
  3. Make sure the decoder model you are using (with the Mamba2 backbone) supports explicit cache and cache-position handling. Check the library’s documentation or implementation details for confirmation, for example with the quick check below.
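
As a quick, version-dependent check (a sketch, not an official recipe), you can print the forward signature of the model class you are using and confirm that it accepts cache_params, use_cache, and cache_position:

import inspect
from transformers import Mamba2Model

# Look for `cache_params`, `use_cache`, and `cache_position`
# among the printed parameters.
print(inspect.signature(Mamba2Model.forward))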

Hope this helps!


First of all, thank you for the quick reply. I do have a follow-up:

I simply use the class Mamba2Model(Mamba2PreTrainedModel) as a decoder, and as far as I can see there is no automatic cache_position management.

In the forward method, we have:

if use_cache:
    if cache_params is None:
        cache_params = Mamba2Cache(
            self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
        )
        cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)
    elif cache_position is None:
        # cases when we do manual forward instead of using `model.generate` which will initiate
        # `cache_position` and makes sure it is not None, throw error here instead of doing some
        # hack to conjecture the current cache position
        raise ValueError(
            "You have to specify the `cache_position` manually when `use_cache=True` and `cache_params` is passed, "
            "you don't have to pass a `cache_params` if you are in prefilling stage because in that case it will "
            "be initialized for you automatically"
        )
else:
    cache_params = None

But the cache_position is never returned, only the cache_params. Am I missing something here?


Another question: the class Mamba2Model(Mamba2PreTrainedModel) initializes the cache position as

cache_position = torch.arange(0, self.config.conv_kernel, device=inputs_embeds.device)

And cache_position is required to be Optional[torch.LongTensor], therefore I am wondering why `cache_position = len(sequence)` should be sufficient.
