import torch
from transformers import StoppingCriteria

# Stop generation once every batch element has generated an EOS token.
# Stores the index of the first generated EOS token for each batch element
# in `self.eos_index`, which can later be used to slice off whatever extra
# tokens were generated after it.
# Note: This is a stateful object. Create a new instance for each call to generate().
class EosStoppingCriteria(StoppingCriteria):
    def __init__(self, tokenizer):
        super().__init__()
        self.eos_token = tokenizer.eos_token_id
        self.done = None
        self.eos_index = None

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        batch_size, seq_len = input_ids.shape

        # Lazily construct per-element state on the first call.
        if self.done is None:
            self.done = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
            self.eos_index = torch.zeros(batch_size, dtype=torch.long, device=input_ids.device)

        # The most recently generated token id for each batch element.
        last_ids = input_ids[:, -1]

        # Mark elements whose last token is EOS as done.
        done_update = self.done | (last_ids == self.eos_token)

        # Record the sequence length at which each element first finished:
        # where the 'done' state just changed, store seq_len; elsewhere, 0.
        eos_index_update = torch.where(done_update ^ self.done, torch.full_like(self.eos_index, seq_len), 0)
        self.eos_index += eos_index_update

        # Update the done flags.
        self.done = done_update

        # Stop only when every batch element is done.
        return bool(self.done.all())
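# Usage sketch (illustrative only; `model`, `tokenizer`, and `input_ids` are
# assumed to be defined, as they are further below):
#
#   criteria = EosStoppingCriteria(tokenizer)
#   outputs = model.generate(input_ids, stopping_criteria=[criteria])
#   trimmed = outputs[0][:criteria.eos_index[0]]  # drop anything after the first EOS
#
# Because the criteria is stateful, build a fresh instance per generate() call.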
# Apply the model's chat template to an optional system message and a user instruction.
def generate_instruction_prompt(tokenizer, system_msg, instruction):
    messages = []
    if system_msg is not None:
        messages.append({ "role": "system", "content": system_msg })
    messages.append({ "role": "user", "content": instruction })
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return prompt
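# For illustration, with mistralai/Mistral-7B-Instruct-v0.2 (whose template has
# no system role), generate_instruction_prompt(tokenizer, None, "Hi") returns
# something like "<s>[INST] Hi [/INST]" -- compare the decoded prompts below.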
# Given a single system message and a list of instructions, generate a prompt
# for each instruction, tokenize the prompts (recording their lengths), then
# collate the tokenized prompts into a batch, adding padding.
# Returns: a batch of tokenized prompts, plus a list of per-prompt lengths (in tokens).
def tokenize_instruction_batch(tokenizer, system_msg, instructions):
    prompts = []
    for instruction in instructions:
        prompt = generate_instruction_prompt(tokenizer, system_msg, instruction)
        prompts.append(prompt)

    encoded_prompts = tokenizer(
        prompts,
        truncation=True,
        return_length=True,
        add_special_tokens=False,
    )
    lengths = encoded_prompts["length"]

    tokenizer_outputs = tokenizer.pad(
        encoded_prompts,
        padding="longest",
        return_tensors='pt',
    )
    return tokenizer_outputs, lengths
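# Note: the batch relies on LEFT padding, so every prompt ends at the same
# position and generation starts in lock-step across the batch. The example
# output below shows the shorter prompt left-padded with the EOS id (2), i.e.
# the tokenizer is assumed to be configured like this (see the setup sketch
# before the driver code below):
#
#   tokenizer.pad_token = tokenizer.eos_token   # Mistral has no dedicated pad token
#   tokenizer.padding_side = "left"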
# Given a single system message and a batch of instructions, batch-generate outputs.
# This identifies the start and end of each generated sequence, slices it at those
# points, then decodes and prints the outputs.
def batch_instruct_generate(
    model,
    tokenizer,
    system_msg,
    instructions,
    max_new_tokens=512,
    generation_config=None,
    show_ids=False,
    skip_special_tokens=False,
    device="cuda:0"
):
    model.to(device)
    tokenizer_outputs, lengths = tokenize_instruction_batch(tokenizer, system_msg, instructions)
    input_ids = tokenizer_outputs["input_ids"].to(device)
    # Pass the attention mask explicitly so the left padding is masked out.
    attention_mask = tokenizer_outputs["attention_mask"].to(device)
    padded_len = input_ids.size(1)

    if show_ids:
        print("input_ids\n\n", input_ids)

    gen_texts = tokenizer.batch_decode(
        input_ids,
        skip_special_tokens=False
    )
    print("Decoded Prompts\n")
    for i, text in enumerate(gen_texts):
        print(f"{i:-^120}")
        print(text)

    stopping_criteria = EosStoppingCriteria(tokenizer)
    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        generation_config=generation_config,
        max_new_tokens=max_new_tokens,
        stopping_criteria=[stopping_criteria],
    )
    if show_ids:
        print("output_ids\n\n", outputs)

    print("Generated Text")
    batch_size, seq_len = outputs.shape
    for i in range(batch_size):
        # With left padding, prompt i starts at padded_len - lengths[i].
        start_index = padded_len - lengths[i]
        # Slice each sequence at the captured eos_index. If a sequence never
        # produced an EOS token (eos_index is still 0), keep it to the end.
        end_index = stopping_criteria.eos_index[i]
        if end_index == 0:
            end_index = seq_len
        sequence = outputs[i][start_index:end_index]
        # Decode the output.
        # Note: We could also collect these into a list and batch-decode them.
        text = tokenizer.decode(
            sequence,
            skip_special_tokens=skip_special_tokens
        )
        print(f"{i:-^120}")
        print(text)
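# Setup sketch: the snippet assumes `model`, `tokenizer`, and `generation_config`
# already exist. A minimal version matching the example output below (the model
# choice and dtype here are assumptions, not part of the original):
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model_name = "mistralai/Mistral-7B-Instruct-v0.2"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Batched generation needs a pad token and left padding (see the note above).
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
generation_config = GenerationConfig.from_pretrained(model_name)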
# This model does not support a system message, so prepend the system-style
# directive to each instruction instead.
system_msg = None
instruction = "Repeat the input, but speak like a pirate.\n\n"

batch_instruct_generate(
    model,
    tokenizer,
    system_msg=system_msg,
    instructions=[
        instruction + "Let's sail to Barbados",
        instruction + "We will be rich!"
    ],
    generation_config=generation_config,
    max_new_tokens=256,
    show_ids=True,
    skip_special_tokens=False
)
Example output with "mistralai_Mistral-7B-Instruct-v0.2":
input_ids
tensor([[ 1, 733, 16289, 28793, 1298, 15882, 272, 2787, 28725, 562,
4085, 737, 264, 17368, 380, 28723, 13, 13, 8779, 28742,
28713, 12432, 298, 25223, 3482, 733, 28748, 16289, 28793],
[ 2, 2, 1, 733, 16289, 28793, 1298, 15882, 272, 2787,
28725, 562, 4085, 737, 264, 17368, 380, 28723, 13, 13,
2324, 622, 347, 6708, 28808, 733, 28748, 16289, 28793]],
device='cuda:0')
Decoded Prompts
-----------------------------------------------------------0------------------------------------------------------------
<s> [INST] Repeat the input, but speak like a pirate.
Let's sail to Barbados [/INST]
-----------------------------------------------------------1------------------------------------------------------------
</s></s><s> [INST] Repeat the input, but speak like a pirate.
We will be rich! [/INST]
output_ids
tensor([[ 1, 733, 16289, 28793, 1298, 15882, 272, 2787, 28725, 562,
4085, 737, 264, 17368, 380, 28723, 13, 13, 8779, 28742,
28713, 12432, 298, 25223, 3482, 733, 28748, 16289, 28793, 20037,
15095, 28724, 28725, 1346, 592, 808, 12432, 396, 28742, 22689,
1167, 15507, 298, 272, 4433, 8919, 302, 25223, 3482, 28808,
627, 2654, 28808, 2],
[ 2, 2, 1, 733, 16289, 28793, 1298, 15882, 272, 2787,
28725, 562, 4085, 737, 264, 17368, 380, 28723, 13, 13,
2324, 622, 347, 6708, 28808, 733, 28748, 16289, 28793, 20037,
15095, 28724, 28725, 478, 28742, 584, 347, 461, 11394, 1162,
6708, 28725, 337, 2654, 28808, 2, 28705, 243, 162, 146,
183, 29274, 31840, 29096]], device='cuda:0')
Generated Text
-----------------------------------------------------------0------------------------------------------------------------
<s> [INST] Repeat the input, but speak like a pirate.
Let's sail to Barbados [/INST] Arr matey, let us set sail an' navigate these waters to the fine island of Barbados! Yarr!</s>
-----------------------------------------------------------1------------------------------------------------------------
<s> [INST] Repeat the input, but speak like a pirate.
We will be rich! [/INST] Arr matey, we'll be jolly well rich, yarr!</s>