Batch inference using open source LLMs

Hey, I am trying to perform batch inference using oasst-sft-7-llama-30b (the Open Assistant model, though I don't think the issue is specific to this model), and I cannot get it to work with a batch size greater than 1: if I set the batch size to more than 1, it just outputs low-quality text (compared to batch=1). Here is the code that I use:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from transformers.pipelines.pt_utils import KeyDataset
from tqdm import tqdm

load_8bit: bool = True
base_model: str = "/storage/oasst-sft-7-llama-30b"
prompt_template: str = "oasst"

# Load the model in 8-bit (requires bitsandbytes) and shard it across available GPUs
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    load_in_8bit=load_8bit,
    torch_dtype=torch.float16,
    device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained(base_model)

# device=0 is dropped here because the model is already placed via device_map="auto"
generator = pipeline('text-generation', max_length=512, model=model, tokenizer=tokenizer,
                     batch_size=4, num_beams=1, top_k=40, top_p=0.1, temperature=0.0)

# dataset is a datasets.Dataset with a "text" column
for outputs in tqdm(generator(KeyDataset(dataset, "text"))):
    print([out["generated_text"] for out in outputs])
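For reference, batched generation with decoder-only models generally assumes left padding and an explicit pad token, and the LLaMA tokenizer does not define a pad token by default. Below is a minimal sketch of that tokenizer setup, reusing the EOS token as the pad token; this is an assumption for illustration, not a confirmed fix for the quality drop:

# Sketch: configure padding before building the pipeline (reusing EOS as pad token is an assumption)
tokenizer.padding_side = "left"  # decoder-only models are padded on the left for generation
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = tokenizer.pad_token_id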

Hi @galprz!

Did you find a solution for this?