Why can't I pass past_key_values = DynamicCache() to a Llama 3 model?

```python
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
import torch

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

messages = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "tell me about twice girl group! under 20 word"},
]

input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]

# Initialize the DynamicCache
past_key_values = DynamicCache()

# Generate with DynamicCache
outputs = model.generate(
    input_ids,
    max_new_tokens=20000,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    past_key_values=past_key_values,  # Pass the cache
    use_cache=True,                   # Enable caching
)

response = outputs[0][input_ids.shape[-1]:]
tokenized_text = tokenizer.decode(response, skip_special_tokens=True)
print(tokenized_text)
```
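(Side note: `generate()` fills the `DynamicCache` in place, so after this call the cache is no longer empty. A quick way to confirm that, using the Cache API's `get_seq_length()`:)

```python
# The first generate() call updated past_key_values in place; it now holds
# key/value states covering (roughly) the whole generated sequence.
print("cached positions:", past_key_values.get_seq_length())
print("sequence length: ", outputs.shape[-1])
```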

```
Loading checkpoint shards: 100% 4/4 [01:04<00:00, 14.06s/it]

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.

Twice is a 9-member South Korean girl group formed by JYP Entertainment, known for their energetic and catchy songs.
```
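(Another side note: the attention-mask warning above is separate from the crash below. A rough sketch of how the first call could be written to avoid it, by tokenizing the rendered template so an explicit `attention_mask` is available:)

```python
# Render the chat template to text, then tokenize it so the tokenizer also
# returns an attention_mask (all ones here, since nothing is padded).
prompt_text = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=False
)
# add_special_tokens=False: the Llama 3 template already starts with
# <|begin_of_text|>, so this avoids inserting a second BOS token.
inputs = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).to(model.device)

outputs = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    pad_token_id=tokenizer.eos_token_id,  # also silences the pad_token_id notice
    max_new_tokens=20000,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    past_key_values=DynamicCache(),  # a fresh cache, as in the original first call
    use_cache=True,
)
```

Recent transformers releases should also let you get the mask directly from `apply_chat_template(..., return_dict=True)`.)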

I get the error in the next step.

If you want to continue the conversation, you can reuse the cache:

```python
follow_up_message = [{"role": "user", "content": "Tell me more about their music. under 20 word"}]
follow_up_input_ids = tokenizer.apply_chat_template(
    messages + follow_up_message,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

follow_up_outputs = model.generate(
    follow_up_input_ids,
    max_new_tokens=200,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    past_key_values=past_key_values,  # Reuse the cache
    use_cache=True,
)

follow_up_response = follow_up_outputs[0][follow_up_input_ids.shape[-1]:]
follow_up_text = tokenizer.decode(follow_up_response, skip_special_tokens=True)
print("\nFollow-up response:")
print(follow_up_text)
```

```
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.

RuntimeError Traceback (most recent call last)
Cell In[2], line 9
2 follow_up_message = [{“role”: “user”, “content”: “Tell me more about their music. under 20 word”}]
3 follow_up_input_ids = tokenizer.apply_chat_template(
4 messages + follow_up_message,
5 add_generation_prompt=True,
6 return_tensors=“pt”
7 ).to(model.device)
----> 9 follow_up_outputs = model.generate(
10 follow_up_input_ids,
11 max_new_tokens=200,
12 eos_token_id=terminators,
13 do_sample=True,
14 temperature=0.6,
15 top_p=0.9,
16 past_key_values=past_key_values, # Reuse the cache
17 use_cache=True,
18 )
20 follow_up_response = follow_up_outputs[0][follow_up_input_ids.shape[-1]:]
21 follow_up_text = tokenizer.decode(follow_up_response, skip_special_tokens=True)

File /opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
113 @functools.wraps(func)
114 def decorate_context(*args, **kwargs):
115 with ctx_factory():
→ 116 return func(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:2024, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
2016 input_ids, model_kwargs = self._expand_inputs_for_generation(
2017 input_ids=input_ids,
2018 expand_size=generation_config.num_return_sequences,
2019 is_encoder_decoder=self.config.is_encoder_decoder,
2020 **model_kwargs,
2021 )
2023 # 13. run sample (it degenerates to greedy search when generation_config.do_sample=False)
→ 2024 result = self._sample(
2025 input_ids,
2026 logits_processor=prepared_logits_processor,
2027 logits_warper=prepared_logits_warper,
2028 stopping_criteria=prepared_stopping_criteria,
2029 generation_config=generation_config,
2030 synced_gpus=synced_gpus,
2031 streamer=streamer,
2032 **model_kwargs,
2033 )
2035 elif generation_mode in (GenerationMode.BEAM_SAMPLE, GenerationMode.BEAM_SEARCH):
2036 # 11. prepare logits warper
2037 prepared_logits_warper = (
2038 self._get_logits_warper(generation_config, device=input_ids.device)
2039 if generation_config.do_sample
2040 else None
2041 )

File /opt/conda/lib/python3.10/site-packages/transformers/generation/utils.py:2982, in GenerationMixin._sample(self, input_ids, logits_processor, stopping_criteria, generation_config, synced_gpus, streamer, logits_warper, **model_kwargs)
2979 model_inputs.update({“output_hidden_states”: output_hidden_states} if output_hidden_states else {})
2981 # forward pass to get next token
→ 2982 outputs = self(**model_inputs, return_dict=True)
2984 if synced_gpus and this_peer_finished:
2985 continue # don’t waste resources running the code we don’t need

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
→ 1553 return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don’t have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
→ 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:1189, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
1186 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1188 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
→ 1189 outputs = self.model(
1190 input_ids=input_ids,
1191 attention_mask=attention_mask,
1192 position_ids=position_ids,
1193 past_key_values=past_key_values,
1194 inputs_embeds=inputs_embeds,
1195 use_cache=use_cache,
1196 output_attentions=output_attentions,
1197 output_hidden_states=output_hidden_states,
1198 return_dict=return_dict,
1199 cache_position=cache_position,
1200 )
1202 hidden_states = outputs[0]
1203 if self.config.pretraining_tp > 1:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
→ 1553 return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don’t have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None

File /opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:1001, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict, cache_position)
989 layer_outputs = self._gradient_checkpointing_func(
990 decoder_layer.__call__,
991 hidden_states,
(…)
998 position_embeddings,
999 )
1000 else:
→ 1001 layer_outputs = decoder_layer(
1002 hidden_states,
1003 attention_mask=causal_mask,
1004 position_ids=position_ids,
1005 past_key_value=past_key_values,
1006 output_attentions=output_attentions,
1007 use_cache=use_cache,
1008 cache_position=cache_position,
1009 position_embeddings=position_embeddings,
1010 )
1012 hidden_states = layer_outputs[0]
1014 if use_cache:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
→ 1553 return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don’t have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
→ 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:734, in LlamaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)
731 hidden_states = self.input_layernorm(hidden_states)
733 # Self Attention
→ 734 hidden_states, self_attn_weights, present_key_value = self.self_attn(
735 hidden_states=hidden_states,
736 attention_mask=attention_mask,
737 position_ids=position_ids,
738 past_key_value=past_key_value,
739 output_attentions=output_attentions,
740 use_cache=use_cache,
741 cache_position=cache_position,
742 position_embeddings=position_embeddings,
743 **kwargs,
744 )
745 hidden_states = residual + hidden_states
747 # Fully Connected

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
1551 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1552 else:
→ 1553 return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
1557 # If we don’t have any hooks, we want to skip the rest of the logic in
1558 # this function, and just call forward.
1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1560 or _global_backward_pre_hooks or _global_backward_hooks
1561 or _global_forward_hooks or _global_forward_pre_hooks):
→ 1562 return forward_call(*args, **kwargs)
1564 try:
1565 result = None

File /opt/conda/lib/python3.10/site-packages/accelerate/hooks.py:170, in add_hook_to_module.<locals>.new_forward(module, *args, **kwargs)
168 output = module._old_forward(*args, **kwargs)
169 else:
→ 170 output = module._old_forward(*args, **kwargs)
171 return module._hf_hook.post_forward(module, output)

File /opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:635, in LlamaSdpaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache, cache_position, position_embeddings, **kwargs)
633 else:
634 cos, sin = position_embeddings
→ 635 query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
637 if past_key_value is not None:
638 # sin and cos are specific to RoPE models; cache_position needed for the static cache
639 cache_kwargs = {“sin”: sin, “cos”: cos, “cache_position”: cache_position}

File /opt/conda/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:275, in apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim)
273 cos = cos.unsqueeze(unsqueeze_dim)
274 sin = sin.unsqueeze(unsqueeze_dim)
→ 275 q_embed = (q * cos) + (rotate_half(q) * sin)
276 k_embed = (k * cos) + (rotate_half(k) * sin)
277 return q_embed, k_embed

RuntimeError: The size of tensor a (0) must match the size of tensor b (68) at non-singleton dimension 2
```

From the code, it looks like the generated response is never added back to `messages` before the follow-up call, and that causes the shape mismatch: the follow-up prompt ends up shorter than the sequence already stored in `past_key_values`, so the model is asked to process zero "new" tokens while the rotary embeddings still cover the whole follow-up prompt, which lines up with the `(0)` vs `(68)` sizes in the error.
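A minimal sketch of that fix, reusing the variables from the code above (whether the re-templated conversation lines up token for token with what is already cached depends on the chat template, so treat it as a sketch rather than a guaranteed recipe):

```python
# Add the assistant's first reply to the conversation before building the
# follow-up prompt, so the new input_ids are a superset of the tokens whose
# key/value states are already stored in past_key_values.
messages.append({"role": "assistant", "content": tokenized_text})
messages.append({"role": "user", "content": "Tell me more about their music. under 20 word"})

follow_up_input_ids = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

follow_up_outputs = model.generate(
    follow_up_input_ids,
    max_new_tokens=200,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.6,
    top_p=0.9,
    past_key_values=past_key_values,  # was filled in place by the first generate() call
    use_cache=True,
)

print(tokenizer.decode(
    follow_up_outputs[0][follow_up_input_ids.shape[-1]:],
    skip_special_tokens=True,
))
```

If you need the cached prefix to match the new input token for token, another option is to keep the raw ids from `outputs` and concatenate the templated follow-up turn onto them instead of re-applying the template to the whole conversation.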