From what I can tell, the recommended approach is usually to set the pad_token to the eos_token after loading a model. However, when running batched inference with Llama2, this approach fails:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "meta-llama/Llama-2-7b-chat-hf"

tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)

# Define PAD Token = EOS Token
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

# use different length sentences to test batching
sentences = [
    "Hello, my dog is a little",
    "Today, I",
    "I love",
    "Repeat after me: I love you.",
    "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
    "You are a human that uses a lot of expletives. For each one of your responses to the user you use expletives liberally. If you find yourself questioning whether it's appropriate or not you're to say fuck that noise and keep using your expletives regardless. Human: Tell me a story. You:",
]

inputs = tokenizer(sentences, return_tensors="pt", padding=True).to(model.device)
print(inputs['input_ids'].shape)

output_sequences = model.generate(**inputs, max_new_tokens=20, do_sample=True, top_p=0.9)

print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))
```
Running this raises:

```
RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
```
Potential solution: I’ve found that setting the pad_token to the bos_token instead actually fixes the issue and allows for batched inference:
```python
# Define PAD Token = BOS Token
tokenizer.pad_token = tokenizer.bos_token
model.config.pad_token_id = model.config.bos_token_id
```
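For comparison, here is a minimal sketch of the other workaround I’ve seen mentioned, which adds a dedicated pad token and resizes the embedding matrix. The `"[PAD]"` string is just an arbitrary choice on my part, and I haven’t tested whether this behaves any better:

```python
# Sketch of an alternative (untested): register a dedicated pad token instead of
# reusing eos/bos, then resize the embeddings so the new token has a row.
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.pad_token_id
```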
I’m wondering if this is something specific to the Llama2 model, or whether using the bos_token as the pad_token is discouraged for some particular reason?
Thanks in advance for any clarification!