Implementation of Stopping Criteria List

Here is my completed code for batch generation. It stops once every sequence in the batch has generated an EOS token, or when max_new_tokens is reached.

import torch
from transformers import StoppingCriteria

# Stop generation after all batch elements have generated an EOS token.
# Stores the index of the first generated EOS token for each batch element in "self.eos_index",
# which can be used to slice off whatever extra junk was generated after it.
# Note: This is a stateful object. A new instance should be created for each call to generate()
# (see the usage sketch right after the class).
class EosStoppingCriteria(StoppingCriteria):
    def __init__(self, tokenizer):
        super().__init__()
        self.eos_token = tokenizer.eos_token_id
        self.done = None
        self.eos_index = None

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        batch_size, seq_len = input_ids.shape
        
        # Lazy construct a bool state for each batch element
        if self.done is None:
            self.done = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
            self.eos_index = torch.zeros(batch_size, dtype=torch.int, device=input_ids.device)

        # Get last token ids in batch
        last_ids = input_ids[:, -1]

        # Flag batch elements whose latest token is EOS, keeping previously finished ones flagged
        done_update = self.done | (last_ids == self.eos_token)
        
        # Record where each newly finished sequence stopped.
        # Where the 'done' state just changed, store the current seq_len
        # (so slicing with [:eos_index] keeps the EOS token), else 0.
        eos_index_update = torch.where(done_update ^ self.done, torch.full_like(self.eos_index, seq_len), 0)

        # Add the update to the indices
        self.eos_index += eos_index_update

        # Update the done flags
        self.done = done_update

        # Return True if every sequence in the batch is done.
        return self.done.all()
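
# Minimal usage sketch for the criteria on its own (illustrative only; assumes
# `model`, `tokenizer`, and `input_ids` are already set up, as they are in
# batch_generate below). Build a fresh instance per generate() call, since
# `done` and `eos_index` carry state between __call__ invocations:
#
#   criteria = EosStoppingCriteria(tokenizer)
#   out = model.generate(input_ids, max_new_tokens=64, stopping_criteria=[criteria])
#   # criteria.eos_index[i] == 0 means sequence i never produced an EOS token.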

# Apply the model's chat template to build an instruction prompt.
def generate_instruction_prompt(tokenizer, instruction, system_msg=None):
    messages = []
    if system_msg is not None:
        messages.append({ "role": "system", "content": system_msg })
    messages.append({ "role": "user", "content": instruction })
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return prompt
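
# For example (the exact string depends on the model's chat template; the
# ChatML-style markers below are just an illustration, not what every model produces):
#
#   generate_instruction_prompt(tokenizer, "Hello!", system_msg="Be brief.")
#   # -> "<|im_start|>system\nBe brief.<|im_end|>\n<|im_start|>user\nHello!<|im_end|>\n<|im_start|>assistant\n"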

def batch_generate(
    model,
    tokenizer,
    prompts,
    max_new_tokens=512,
    generation_config=None,
    skip_special_tokens=False,
    max_length=None,
    device='cpu'
):
    model.to(device)

    encoded_prompts = tokenizer(
        prompts,
        truncation=True,
        return_length=True,
        add_special_tokens=False,
        max_length=max_length,
    )
    
    lengths = encoded_prompts["length"]
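    # Pad every prompt to the longest in the batch. Note: the start-index math
    # further down assumes left padding (tokenizer.padding_side == "left"),
    # which is the usual setup for batched generation with decoder-only models.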
    tokenizer_outputs = tokenizer.pad(
        encoded_prompts,
        padding="longest",
        return_tensors='pt',
    )
    
    input_ids = tokenizer_outputs['input_ids'].to(device)
    # Pass the attention mask along so padded positions are ignored during generation.
    attention_mask = tokenizer_outputs['attention_mask'].to(device)
    padded_len = input_ids.size(1)
    stopping_criteria = EosStoppingCriteria(tokenizer)

    outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        generation_config=generation_config,
        max_new_tokens=max_new_tokens,
        stopping_criteria=[stopping_criteria],
    )
    
    batch_size, seq_len = outputs.shape
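    # Each row of `outputs` is [left padding][prompt][generated tokens ...],
    # so below we slice off the padding at the front and anything after the
    # first EOS at the back.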

    output_ids = []
    
    for i in range(batch_size):
        # Compute the index of the first real (non-padding) prompt token.
        start_index = padded_len - lengths[i]
        end_index = stopping_criteria.eos_index[i]
        # eos_index == 0 means this sequence never generated EOS
        # (it hit max_new_tokens), so keep the full output.
        if end_index == 0:
            end_index = seq_len

        # Slice out the prompt plus everything generated up to and including the first EOS.
        output_ids.append(outputs[i][start_index:end_index])

    output_texts = tokenizer.batch_decode(
        output_ids,
        skip_special_tokens=skip_special_tokens
    )

    return output_texts

Example usage:

from transformers import GenerationConfig

generation_config = GenerationConfig(
    max_new_tokens=512,
    do_sample=True,
    top_k=20,
    top_p=0.9,
    temperature=0.7,
    repetition_penalty=1.15,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)

instruction = "Repeat the input, but speak like a pirate.\n\n"
output_texts = batch_generate(
    model,
    tokenizer,
    prompts=[
        generate_instruction_prompt(tokenizer, instruction + "Let's sail to Barbados"),
        generate_instruction_prompt(tokenizer, instruction + "We will be rich!"),
    ],
    generation_config=generation_config,
    max_new_tokens=256,
    skip_special_tokens=False,
    device="cuda:0"
)

for text in output_texts:
    print("---")
    print(text)
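
One setup detail worth spelling out: the slicing in batch_generate assumes the tokenizer pads on the left, and padding a batch needs a pad token in the first place. If your tokenizer isn't already configured that way, here is a sketch of the setup I'm assuming (the model name is just a placeholder):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "your-model-name"  # placeholder: substitute the model you are using
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.padding_side = "left"  # decoder-only models should be left-padded for batch generation
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token  # common fallback when there is no dedicated pad token
model = AutoModelForCausalLM.from_pretrained(model_name)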