Here is my completed code for batch generation. It stops as soon as every sequence in the batch has generated an EOS token, or when max_new_tokens is reached.
import torch
from transformers import StoppingCriteria
# Stop generation after all batch elements have generated an EOS token.
# Stores, for each batch element, the sequence length up to and including the first
# generated EOS token in "self.eos_index", which can be used to slice off whatever
# extra junk was generated after it.
# Note: This is a stateful object. A new instance should be created for each call to generate().
class EosStoppingCriteria(StoppingCriteria):
    def __init__(self, tokenizer):
        super().__init__()
        self.eos_token = tokenizer.eos_token_id
        self.done = None
        self.eos_index = None

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        batch_size, seq_len = input_ids.shape

        # Lazily construct a bool state for each batch element on the first call
        if self.done is None:
            self.done = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
            self.eos_index = torch.zeros(batch_size, dtype=torch.int, device=input_ids.device)

        # Get the last token id of each sequence in the batch
        last_ids = input_ids[:, -1]

        # Mark the batch elements whose last token is EOS
        done_update = self.done | (last_ids == self.eos_token)

        # Record where each sequence stopped: where the 'done' state just changed,
        # store seq_len (the length including the EOS token), else 0
        eos_index_update = torch.where(done_update ^ self.done, torch.full_like(self.eos_index, seq_len), 0)

        # Add the update to the indices
        self.eos_index += eos_index_update

        # Update the done flags
        self.done = done_update

        # Return True if all sequences are done
        return self.done.all()
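You can sanity-check the state tracking without running a model by driving the criteria by hand with dummy token ids. A minimal sketch, assuming a stub tokenizer whose EOS id is 2 (both the stub and the token ids are made up for the demo):

from types import SimpleNamespace

# Stub tokenizer with a hypothetical EOS id of 2, just for this demo.
fake_tokenizer = SimpleNamespace(eos_token_id=2)
criteria = EosStoppingCriteria(fake_tokenizer)

# Step 1: neither sequence in the batch of two ends in EOS yet.
print(criteria(torch.tensor([[5, 7], [9, 4]]), scores=None))              # tensor(False)

# Step 2: the first sequence emits EOS (id 2).
print(criteria(torch.tensor([[5, 7, 2], [9, 4, 8]]), scores=None))        # tensor(False)

# Step 3: the second sequence emits EOS, so generation can stop.
print(criteria(torch.tensor([[5, 7, 2, 2], [9, 4, 8, 2]]), scores=None))  # tensor(True)

# eos_index holds each sequence's length up to and including its EOS token.
print(criteria.eos_index)  # tensor([3, 4], dtype=torch.int32)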
# Apply the model's chat template
def generate_instruction_prompt(tokenizer, instruction, system_msg=None):
    messages = []
    if system_msg is not None:
        messages.append({"role": "system", "content": system_msg})
    messages.append({"role": "user", "content": instruction})
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return prompt
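What this produces depends entirely on the model's chat template. As a rough illustration, for a tokenizer with a ChatML-style template (an assumption; check your model card for the actual format):

prompt = generate_instruction_prompt(tokenizer, "Summarize this.", system_msg="You are concise.")
# With a ChatML-style template, the rendered prompt looks roughly like:
# <|im_start|>system
# You are concise.<|im_end|>
# <|im_start|>user
# Summarize this.<|im_end|>
# <|im_start|>assistant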
def batch_generate(
    model,
    tokenizer,
    prompts,
    max_new_tokens=512,
    generation_config=None,
    skip_special_tokens=False,
    max_length=None,
    device='cpu'
):
    model.to(device)

    # Tokenize without padding first, so we know each prompt's true length
    encoded_prompts = tokenizer(
        prompts,
        truncation=True,
        return_length=True,
        add_special_tokens=False,
        max_length=max_length,
    )
    lengths = encoded_prompts["length"]

    # Pad the batch to the longest prompt. The prompt slicing below assumes
    # left padding, i.e. tokenizer.padding_side == "left".
    tokenizer_outputs = tokenizer.pad(
        encoded_prompts,
        padding="longest",
        return_tensors='pt',
    )
    input_ids = tokenizer_outputs['input_ids'].to(device)
    attention_mask = tokenizer_outputs['attention_mask'].to(device)
    padded_len = input_ids.size(1)

    stopping_criteria = EosStoppingCriteria(tokenizer)
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        generation_config=generation_config,
        max_new_tokens=max_new_tokens,
        stopping_criteria=[stopping_criteria],
    )
    batch_size, seq_len = outputs.shape
    output_ids = []
    for i in range(batch_size):
        # Compute the index of the first real (non-padding) prompt token
        start_index = padded_len - lengths[i]
        # Slice each sequence at the captured eos_index; an index of 0 means
        # the sequence never produced EOS, so keep it up to seq_len
        end_index = stopping_criteria.eos_index[i]
        if end_index == 0:
            end_index = seq_len
        output_ids.append(outputs[i][start_index:end_index])

    output_texts = tokenizer.batch_decode(
        output_ids,
        skip_special_tokens=skip_special_tokens
    )
    return output_texts
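For completeness, here is one way to set up the model and tokenizer before calling batch_generate (a minimal sketch; the checkpoint name is a placeholder, and the padding settings are the ones the prompt-slicing logic assumes):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "your-org/your-chat-model"  # placeholder; use your actual checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# batch_generate computes the prompt start as padded_len - length, which assumes left padding.
tokenizer.padding_side = "left"
# Many causal LM tokenizers define no pad token; reusing EOS is a common workaround.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token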
Example usage:
from transformers import GenerationConfig
generation_config = GenerationConfig(
    max_new_tokens=512,
    do_sample=True,
    top_k=20,
    top_p=0.9,
    temperature=0.7,
    repetition_penalty=1.15,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)
instruction = "Repeat the input, but speak like a pirate.\n\n"
output_texts = batch_generate(
    model,
    tokenizer,
    prompts=[
        generate_instruction_prompt(tokenizer, instruction + "Let's sail to Barbados"),
        generate_instruction_prompt(tokenizer, instruction + "We will be rich!"),
    ],
    generation_config=generation_config,
    max_new_tokens=256,
    skip_special_tokens=False,
    device="cuda:0",
)
for text in output_texts:
    print("---")
    print(text)
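Note that each returned string still includes its prompt, since the slice starts at the first real prompt token. If you only want the completions, change the slice start in batch_generate from padded_len - lengths[i] to padded_len: with left padding, everything before padded_len is padding plus prompt.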