I have just switched from LLaMA 1 to Llama 2 (same hardware), and generating text now takes 10x longer. I'm guessing something in my code is making it this much slower. Can anyone point out any mistakes in the code below?
from transformers import LlamaForCausalLM, LlamaTokenizerFast


def llama_textgen(prompt, model, tokenizer, max_tokens=4):
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        inputs.input_ids,
        max_new_tokens=max_tokens,
        do_sample=False,  # greedy decoding
    )
    # Decode only the newly generated tokens, dropping the prompt portion.
    text_outputs = tokenizer.batch_decode(
        outputs[:, inputs.input_ids.shape[-1]:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
    return text_outputs


model_base_name = "meta-llama/Llama-2-70b-chat-hf"
model = LlamaForCausalLM.from_pretrained(
    model_base_name,
    cache_dir="/llama2_chat",
)
tokenizer = LlamaTokenizerFast.from_pretrained(model_base_name)

prompt = 'I liked "Breaking Bad" and "Band of Brothers". Do you have any recommendations of other shows I might like?\n'
response = llama_textgen(
    prompt,
    model=model,
    tokenizer=tokenizer,
)