Llama 2 70B chat: RuntimeError during generate (shape '[1, 161, 64, 128]' is invalid)

I want to use meta-llama/Llama-2-70b-chat-hf in my notebook, but generation fails with a RuntimeError.

I am using Transformers version 4.30.2.

Code:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-chat-hf")
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-70b-chat-hf", device_map="auto", torch_dtype=torch.bfloat16
    )

    prompt = """\
        <s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe.  Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.

If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>

There's a llama in my garden 😱 What should I do? [/INST]
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(inputs.input_ids.to("cuda"), max_new_tokens=1000)
    print(
        tokenizer.batch_decode(
            outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
    )

Error:

Exception has occurred: RuntimeError
shape '[1, 161, 64, 128]' is invalid for input of size 164864
  File "/ceph/hpc/home/dstepec/NLP/generative-code-gen/llama_gen.py", line 21, in <module>
    outputs = model.generate(inputs.input_ids.to("cuda"), max_new_tokens=1000)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: shape '[1, 161, 64, 128]' is invalid for input of size 164864

Stack trace:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[6], line 1
----> 1 outputs = model.generate(inputs.input_ids.to("cuda"), max_new_tokens=1000)

File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/transformers/generation/utils.py:1522, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, **kwargs)
   1516         raise ValueError(
   1517             "num_return_sequences has to be 1 when doing greedy search, "
   1518             f"but is {generation_config.num_return_sequences}."
   1519         )
   1521     # 11. run greedy search
-> 1522     return self.greedy_search(
   1523         input_ids,
   1524         logits_processor=logits_processor,
   1525         stopping_criteria=stopping_criteria,
   1526         pad_token_id=generation_config.pad_token_id,
   1527         eos_token_id=generation_config.eos_token_id,
   1528         output_scores=generation_config.output_scores,
   1529         return_dict_in_generate=generation_config.return_dict_in_generate,
   1530         synced_gpus=synced_gpus,
   1531         streamer=streamer,
   1532         **model_kwargs,
   1533     )
   1535 elif is_contrastive_search_gen_mode:
   1536     if generation_config.num_return_sequences > 1:

File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/transformers/generation/utils.py:2339, in GenerationMixin.greedy_search(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
   2336 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
   2338 # forward pass to get next token
-> 2339 outputs = self(
   2340     **model_inputs,
   2341     return_dict=True,
   2342     output_attentions=output_attentions,
   2343     output_hidden_states=output_hidden_states,
   2344 )
   2346 if synced_gpus and this_peer_finished:
   2347     continue  # don't waste resources running the code we don't need

File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:688, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
    685 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    687 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 688 outputs = self.model(
    689     input_ids=input_ids,
    690     attention_mask=attention_mask,
    691     position_ids=position_ids,
    692     past_key_values=past_key_values,
    693     inputs_embeds=inputs_embeds,
    694     use_cache=use_cache,
    695     output_attentions=output_attentions,
    696     output_hidden_states=output_hidden_states,
    697     return_dict=return_dict,
    698 )
    700 hidden_states = outputs[0]
    701 logits = self.lm_head(hidden_states)

File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:578, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
    570     layer_outputs = torch.utils.checkpoint.checkpoint(
    571         create_custom_forward(decoder_layer),
    572         hidden_states,
   (...)
    575         None,
    576     )
    577 else:
--> 578     layer_outputs = decoder_layer(
    579         hidden_states,
    580         attention_mask=attention_mask,
    581         position_ids=position_ids,
    582         past_key_value=past_key_value,
    583         output_attentions=output_attentions,
    584         use_cache=use_cache,
    585     )
    587 hidden_states = layer_outputs[0]
    589 if use_cache:

File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:292, in LlamaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)
    289 hidden_states = self.input_layernorm(hidden_states)
    291 # Self Attention
--> 292 hidden_states, self_attn_weights, present_key_value = self.self_attn(
    293     hidden_states=hidden_states,
    294     attention_mask=attention_mask,
    295     position_ids=position_ids,
    296     past_key_value=past_key_value,
    297     output_attentions=output_attentions,
    298     use_cache=use_cache,
    299 )
    300 hidden_states = residual + hidden_states
    302 # Fully Connected

File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
    163         output = old_forward(*args, **kwargs)
    164 else:
--> 165     output = old_forward(*args, **kwargs)
    166 return module._hf_hook.post_forward(module, output)

File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:195, in LlamaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)
    192 bsz, q_len, _ = hidden_states.size()
    194 query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
--> 195 key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    196 value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    198 kv_seq_len = key_states.shape[-2]

RuntimeError: shape '[1, 161, 64, 128]' is invalid for input of size 164864

The same error does not occur with meta-llama/Llama-2-13b-chat-hf.

What am I doing wrong?

Update: upgrading to Transformers 4.31 solved the problem. Llama-2-70b uses grouped-query attention (8 key/value heads for 64 query heads), and support for the num_key_value_heads config field was only added in Transformers 4.31. On 4.30.2 the attention layer reshapes the k_proj/v_proj output as if there were 64 key/value heads, which produces exactly the shape error above. The 7B and 13B checkpoints do not use grouped-query attention, which is why they work on the older version.
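
For anyone hitting the same error: you can confirm the cause without loading the full model by checking the config and your installed version. This is just a quick sketch (it assumes you have access to the gated meta-llama repo and that the config exposes the standard LlamaConfig fields):

import transformers
from transformers import AutoConfig

# Grouped-query attention support for Llama 2 landed in Transformers 4.31.0
print(transformers.__version__)

config = AutoConfig.from_pretrained("meta-llama/Llama-2-70b-chat-hf")
print(config.num_attention_heads)                         # 64 query heads
print(config.num_key_value_heads)                         # 8 key/value heads -> GQA
print(config.hidden_size // config.num_attention_heads)   # head_dim = 128

# On 4.30.2 the attention layer reshapes the k_proj output assuming 64 heads:
#   expected elements: 1 * 161 * 64 * 128 = 1,318,912
#   actual k_proj output for this prompt: 161 * (8 * 128) = 164,864  <- the size in the error

If num_key_value_heads is smaller than num_attention_heads, the checkpoint uses grouped-query attention and needs transformers >= 4.31.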