Parallelize Mistral/Llama 2 output

As shown in the example below, I want to extract the fruits mentioned in a given sentence. I know how to do this for a single sentence/description, but I don't know how to run it in parallel. In short, I want to generate with a batch size greater than 1.

This is what I tried; please let me know the correct way to do it. Some notes:

  • I'm not sure whether this is correct: tokenizer.pad_token = tokenizer.eos_token
  • Since I pad with eos_token, does passing EOS tokens to the shorter sentence break its generation?

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

device = "cuda"  # the device to load the model onto

# Load the model in 4-bit (requires bitsandbytes); device_map="auto" places the weights.
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", load_in_4bit=True, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

prompt = "Describe the fruits in this sentence"
unique_descriptions = ["apples, oranges", "pineapple, apple, pen"]

# Build one chat (user / assistant / user turns) per description.
def get_message(description: str, prompt: str):
    return [
        {"role": "user", "content": "You are a helpful assistant."},
        {"role": "assistant", "content": prompt},
        {"role": "user", "content": description},
    ]

# Render each chat into a prompt string (tokenize=False returns text, not token ids).
messages = [
    tokenizer.apply_chat_template(
        get_message(description, prompt),
        tokenize=False,
        add_generation_prompt=True,
    )
    for description in unique_descriptions
]

# The tokenizer defines no pad token, so reuse EOS to pad the batch.
tokenizer.pad_token = tokenizer.eos_token
tokenized_text = tokenizer(
    messages,
    truncation=True,
    padding=True,  # pad every prompt to the longest one in the batch
    return_tensors="pt",
).to(device)

# Generate for the whole batch at once; **tokenized_text passes input_ids and attention_mask.
generated_ids = model.generate(**tokenized_text, max_new_tokens=1000, do_sample=True)
decoded = tokenizer.batch_decode(generated_ids)
```
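On the two notes above: reusing eos_token as pad_token is a common workaround when a tokenizer has no pad token, and the padded positions should not meaningfully change what the shorter prompt generates, provided the attention_mask is passed to generate (the **tokenized_text call does this) and the padding is on the left. Left padding matters for decoder-only models because new tokens are appended after the last position, so every row has to end with a real token; in transformers this is requested with tokenizer.padding_side = "left", and the leading </s> in decoded[0] below suggests the tokenizer already pads on the left. The following is only a pure-Python illustration of what left padding produces, not the tokenizer's implementation; left_pad is a hypothetical helper and the token ids are made up.

```python
from typing import List, Tuple


def left_pad(batch: List[List[int]], pad_id: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Left-pad token-id sequences to a common length and build the attention mask.

    Real tokens get mask 1 and padding gets mask 0, so the model ignores the padding.
    Padding goes on the left so that every row ends with a real token, which is where
    a decoder-only model appends the next generated token.
    """
    max_len = max(len(seq) for seq in batch)
    input_ids, attention_mask = [], []
    for seq in batch:
        n_pad = max_len - len(seq)
        input_ids.append([pad_id] * n_pad + seq)
        attention_mask.append([0] * n_pad + [1] * len(seq))
    return input_ids, attention_mask


# Hypothetical ids: 1 = <s> (BOS), 2 = </s> (EOS, reused as the pad token).
EOS_ID = 2
prompts = [[1, 10, 11], [1, 20, 21, 22, 23]]
ids, mask = left_pad(prompts, pad_id=EOS_ID)
print(ids)   # [[2, 2, 1, 10, 11], [1, 20, 21, 22, 23]]
print(mask)  # [[0, 0, 1, 1, 1], [1, 1, 1, 1, 1]]
```

The EOS ids in the first row are masked out, so the model never attends to them.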

Outputs

I see these strange padding tokens before the prompt, which don't make sense to me.

decoded[1]
>>> '<s><s> [INST] You are a helpful assistant. [/INST]Describe the fruits in this sentence</s> [INST] pineapple, apple, pen [/INST] The first two items in the sentence are fruits. A pineapple is a tropical plant with an edible fruit consisting of many berries that grow in a compact clusters. The fruit of a pineapple is sweet, juicy, and has a fibrous texture. An apple is a sweet, edible fruit that comes from a tree. It is round or oval in shape and has a red, green, or yellow skin, depending on the variety. The last item in the sentence, "pen," is not a fruit. It is a writing instrument with a removable or retractable nib or ballpoint for applying ink to a surface in order to write.</s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s></s>'


decoded[0]
>>> '</s><s><s> [INST] You are a helpful assistant. [/INST]Describe the fruits in this sentence</s> [INST] apples, oranges [/INST] Apples and oranges are both common fruits that are popular around the world. An apple is a round or oval-shaped fruit with a thick, usually red, green, or yellow skin, and a juicy, crisp white or yellow flesh. Apples are native to Central Asia and have been cultivated for thousands of years. They are rich in dietary fiber and vitamin C.\n\nOranges, on the other hand, are a citrus fruit with a bright orange rind and juicy, sweet flesh. They are native to Southeast Asia, and their juicy segments are enclosed in a relatively thin membrane. Oranges are packed with vitamin C and also contain other essential nutrients like folate, potassium, and dietary fiber.</s>'

Just wondering whether you need to pass skip_special_tokens=True when decoding, something like:

```python
decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
```
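skip_special_tokens=True should drop <s> and </s> (and therefore the EOS padding) from the decoded strings, so it cleans up the output but does not change what the model saw. The doubled <s> at the start of both outputs is a separate, input-side issue: Mistral's chat template already writes <s> into the string returned by apply_chat_template(..., tokenize=False), and calling tokenizer(...) on that string adds another BOS by default, so passing add_special_tokens=False to that tokenizer call should avoid the duplicate. The sketch below only imitates the filtering step with a hypothetical toy vocabulary; decode here is an illustrative function, not how transformers implements decoding.

```python
from typing import Dict, Iterable, Set


def decode(ids: Iterable[int], vocab: Dict[int, str], special_ids: Set[int],
           skip_special_tokens: bool = False) -> str:
    """Toy decoder: map ids to strings, optionally dropping special tokens."""
    pieces = [vocab[i] for i in ids if not (skip_special_tokens and i in special_ids)]
    return "".join(pieces)


# Hypothetical vocabulary: 1 = <s>, 2 = </s> (also used as padding).
VOCAB = {1: "<s>", 2: "</s>", 10: " apples", 11: ",", 12: " oranges"}
SPECIAL = {1, 2}
row = [2, 1, 1, 10, 11, 12, 2]  # left pad, doubled BOS, text, EOS

print(decode(row, VOCAB, SPECIAL))                            # </s><s><s> apples, oranges</s>
print(decode(row, VOCAB, SPECIAL, skip_special_tokens=True))  #  apples, oranges
```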