Hey, I am trying to perform batch inference using oasst-sft-7-llama-30b (the Open Assistant model, but I don't think the issue is specific to this model type), and I cannot get it to work with a batch size > 1. If I set the batch size to anything more than 1, it just outputs low-quality text (compared to batch size = 1). Here is the code that I use:
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from transformers.pipelines.pt_utils import KeyDataset
import bitsandbytes as bnb  # needed for load_in_8bit

load_8bit: bool = True
base_model: str = "/storage/oasst-sft-7-llama-30b"
prompt_template: str = "oasst"

# load the model in 8-bit, sharded across available GPUs
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    load_in_8bit=load_8bit,
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(base_model)

# no device= argument here: device_map="auto" already places the model
generator = pipeline("text-generation", model=model, tokenizer=tokenizer, max_length=512,
                     batch_size=4, num_beams=1, top_k=40, top_p=0.1, temperature=0.0)

for outputs in tqdm(generator(KeyDataset(dataset, "text"))):
    print([out["generated_text"] for out in outputs])
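For context, my understanding is that batched generation with a decoder-only model needs a pad token and left padding (tokenizer.padding_side = "left"); with the default right padding, the pad tokens end up between the prompt and the newly generated tokens, so the model conditions on padding just before it starts generating. A toy sketch of the difference (the pad_batch helper and PAD id are hypothetical, not the real tokenizer):

```python
PAD = 0  # hypothetical pad token id

def pad_batch(seqs, side):
    """Pad a batch of token-id lists to equal length on the given side."""
    n = max(len(s) for s in seqs)
    out = []
    for s in seqs:
        pads = [PAD] * (n - len(s))
        out.append(pads + s if side == "left" else s + pads)
    return out

batch = [[5, 6, 7], [8, 9]]
# right padding: generation would continue after the trailing PAD
print(pad_batch(batch, "right"))  # [[5, 6, 7], [8, 9, 0]]
# left padding: every prompt ends immediately before generation starts
print(pad_batch(batch, "left"))   # [[5, 6, 7], [0, 8, 9]]
```

Is this kind of padding mismatch a plausible cause of the quality drop, or is something else wrong in how I set up the pipeline?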