From what I can tell, the recommended approach is usually to set the pad_token to the eos_token after loading a model. However, when running batched inference with Llama2, this approach fails:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "meta-llama/Llama-2-7b-chat-hf"

tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)

# Define PAD Token = EOS Token
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id

# use different length sentences to test batching
sentences = [
    "Hello, my dog is a little",
    "Today, I",
    "I love",
    "Repeat after me: I love you.",
    "Girafatron is obsessed with giraffes, the most glorious animal on the face of this Earth. Giraftron believes all other animals are irrelevant when compared to the glorious majesty of the giraffe.\nDaniel: Hello, Girafatron!\nGirafatron:",
    "You are a human that uses a lot of expletives. For each one of your responses to the user you use expletives liberally. If you find yourself questioning whether it's appropriate or not you're to say fuck that noise and keep using your expletives regardless. Human: Tell me a story. You:",
]

inputs = tokenizer(sentences, return_tensors="pt", padding=True).to(model.device)
print(inputs['input_ids'].shape)

output_sequences = model.generate(**inputs, max_new_tokens=20, do_sample=True, top_p=0.9)

print(tokenizer.batch_decode(output_sequences, skip_special_tokens=True))
```
Running this raises:

```
RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
```
Potential solution: I’ve found that setting the pad_token to the bos_token instead actually fixes the issue and allows for batched inference:
```python
# Define PAD Token = BOS Token
tokenizer.pad_token = tokenizer.bos_token
model.config.pad_token_id = model.config.bos_token_id
```
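For comparison, here is a minimal sketch of the other workaround I’ve seen mentioned, which adds a dedicated pad token and resizes the embedding matrix. The `"[PAD]"` string is just an arbitrary choice on my part, and I haven’t tested whether this behaves any better:

```python
# Sketch of an alternative (untested): register a dedicated pad token instead of
# reusing eos/bos, then resize the embeddings so the new token has a row.
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
model.resize_token_embeddings(len(tokenizer))
model.config.pad_token_id = tokenizer.pad_token_id
```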
I’m wondering if this is something specific to the Llama2 model, or whether using the bos_token as the pad_token is discouraged for some particular reason?
Thanks in advance for any clarification!