I want to use meta-llama/Llama-2-70b-chat-hf in my notebook, but model.generate fails with a RuntimeError about an invalid tensor shape.
I am using transformers version 4.30.2.
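In case the exact environment matters, this is the quick check I can run in the same notebook kernel (only the transformers version is pinned on purpose; torch and accelerate are whatever the environment in the traceback below resolves to):
import accelerate
import torch
import transformers
# Print the versions of the libraries that appear in the traceback below.
print("transformers:", transformers.__version__)   # 4.30.2
print("torch:", torch.__version__)
print("accelerate:", accelerate.__version__)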
Code:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-chat-hf")
    # Shard the 70B checkpoint across the available GPUs in bfloat16.
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-70b-chat-hf", device_map="auto", torch_dtype=torch.bfloat16
    )
    # Llama-2 chat prompt: system prompt inside <<SYS>> tags, then the user turn.
    prompt = """\
<s>[INST] <<SYS>>
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
<</SYS>>
There's a llama in my garden 😱 What should I do? [/INST]
"""
    inputs = tokenizer(prompt, return_tensors="pt")
    # Greedy decoding; this is the call that raises the RuntimeError below.
    outputs = model.generate(inputs.input_ids.to("cuda"), max_new_tokens=1000)
    print(
        tokenizer.batch_decode(
            outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )[0]
    )
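For reference, a couple of print statements dropped in right after the tokenizer call (reusing the inputs variable from the script above) show what goes into generate; the sequence length lines up with the 161 in the error below:
# Diagnostic only: inspect what is passed to model.generate.
print(inputs.input_ids.shape)   # should show torch.Size([1, 161]): batch of 1, 161-token prompt
print(inputs.input_ids.dtype)   # torch.int64 token ids, still on CPU before .to("cuda")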
Error:
Exception has occurred: RuntimeError
shape '[1, 161, 64, 128]' is invalid for input of size 164864
File "/ceph/hpc/home/dstepec/NLP/generative-code-gen/llama_gen.py", line 21, in <module>
outputs = model.generate(inputs.input_ids.to("cuda"), max_new_tokens=1000)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: shape '[1, 161, 64, 128]' is invalid for input of size 164864
Stack trace:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[6], line 1
----> 1 outputs = model.generate(inputs.input_ids.to("cuda"), max_new_tokens=1000)
File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/transformers/generation/utils.py:1522, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, **kwargs)
1516 raise ValueError(
1517 "num_return_sequences has to be 1 when doing greedy search, "
1518 f"but is {generation_config.num_return_sequences}."
1519 )
1521 # 11. run greedy search
-> 1522 return self.greedy_search(
1523 input_ids,
1524 logits_processor=logits_processor,
1525 stopping_criteria=stopping_criteria,
1526 pad_token_id=generation_config.pad_token_id,
1527 eos_token_id=generation_config.eos_token_id,
1528 output_scores=generation_config.output_scores,
1529 return_dict_in_generate=generation_config.return_dict_in_generate,
1530 synced_gpus=synced_gpus,
1531 streamer=streamer,
1532 **model_kwargs,
1533 )
1535 elif is_contrastive_search_gen_mode:
1536 if generation_config.num_return_sequences > 1:
File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/transformers/generation/utils.py:2339, in GenerationMixin.greedy_search(self, input_ids, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, streamer, **model_kwargs)
2336 model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
2338 # forward pass to get next token
-> 2339 outputs = self(
2340 **model_inputs,
2341 return_dict=True,
2342 output_attentions=output_attentions,
2343 output_hidden_states=output_hidden_states,
2344 )
2346 if synced_gpus and this_peer_finished:
2347 continue # don't waste resources running the code we don't need
File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
163 output = old_forward(*args, **kwargs)
164 else:
--> 165 output = old_forward(*args, **kwargs)
166 return module._hf_hook.post_forward(module, output)
File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:688, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
685 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
687 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 688 outputs = self.model(
689 input_ids=input_ids,
690 attention_mask=attention_mask,
691 position_ids=position_ids,
692 past_key_values=past_key_values,
693 inputs_embeds=inputs_embeds,
694 use_cache=use_cache,
695 output_attentions=output_attentions,
696 output_hidden_states=output_hidden_states,
697 return_dict=return_dict,
698 )
700 hidden_states = outputs[0]
701 logits = self.lm_head(hidden_states)
File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:578, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict)
570 layer_outputs = torch.utils.checkpoint.checkpoint(
571 create_custom_forward(decoder_layer),
572 hidden_states,
(...)
575 None,
576 )
577 else:
--> 578 layer_outputs = decoder_layer(
579 hidden_states,
580 attention_mask=attention_mask,
581 position_ids=position_ids,
582 past_key_value=past_key_value,
583 output_attentions=output_attentions,
584 use_cache=use_cache,
585 )
587 hidden_states = layer_outputs[0]
589 if use_cache:
File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
163 output = old_forward(*args, **kwargs)
164 else:
--> 165 output = old_forward(*args, **kwargs)
166 return module._hf_hook.post_forward(module, output)
File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:292, in LlamaDecoderLayer.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)
289 hidden_states = self.input_layernorm(hidden_states)
291 # Self Attention
--> 292 hidden_states, self_attn_weights, present_key_value = self.self_attn(
293 hidden_states=hidden_states,
294 attention_mask=attention_mask,
295 position_ids=position_ids,
296 past_key_value=past_key_value,
297 output_attentions=output_attentions,
298 use_cache=use_cache,
299 )
300 hidden_states = residual + hidden_states
302 # Fully Connected
File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/accelerate/hooks.py:165, in add_hook_to_module.<locals>.new_forward(*args, **kwargs)
163 output = old_forward(*args, **kwargs)
164 else:
--> 165 output = old_forward(*args, **kwargs)
166 return module._hf_hook.post_forward(module, output)
File ~/miniconda3/envs/generative_env/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:195, in LlamaAttention.forward(self, hidden_states, attention_mask, position_ids, past_key_value, output_attentions, use_cache)
192 bsz, q_len, _ = hidden_states.size()
194 query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
--> 195 key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
196 value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
198 kv_seq_len = key_states.shape[-2]
RuntimeError: shape '[1, 161, 64, 128]' is invalid for input of size 164864
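To add my own back-of-the-envelope check on the numbers in the error (this is just my reasoning, not something the traceback prints): the view target [1, 161, 64, 128] needs 1 × 161 × 64 × 128 elements, while the k_proj output only has 164864, and the gap is exactly a factor of 8.
# My own sanity check on the shape mismatch (not part of the traceback):
target = 1 * 161 * 64 * 128   # elements required by view([1, 161, 64, 128])
actual = 164864               # elements actually in the k_proj output
print(target)                 # 1318912
print(actual // 161)          # 1024 -> only 1024 = 8 * 128 key features per token
print(target // actual)       # 8    -> exactly 8x fewer than 64 heads * 128 dims
So the key projection seems to produce 1024 features per token instead of the 64 × 128 = 8192 that this version of modeling_llama.py expects at line 195. I don't know whether that points at the 70B checkpoint's attention layout or at something in my setup, which is why I'm asking.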
The same error does not happen with meta-llama/Llama-2-13b-chat-hf.
What am I doing wrong?