Implementation of Stopping Criteria List

Dear HF,

Would someone please show me how to use the stopping criteria?

I would like to stop generation if certain words / phrases are generated e.g. “foo bar”, “moo bar foo”

The instructions seem to use the BERT tokeniser to generate the tokens of the stop sequence?

I am trying to implement this with the OPT model (13b) - would I still use the BERT tokeniser?

Would anyone be able to show an example of using this successfully?

3 Likes

Hi!

This is my first reply ever, so please don’t judge too strictly :slight_smile:

I was able to solve a similar task today (for GPT-2), so maybe my suggestion will help you.

So, first things first:

  1. Q: “The instructions seem to use the Bert tokeniser - to generate tokens of the stop sequence?”
    A: In order to stop the sequence, the model needs to know which token (or sequence of tokens) should trigger the stop. tokenizer.encode is used so you can see which token ids correspond to your word (e.g. “foo bar”), e.g.:

stop_words_ids = [
    tokenizer.encode(stop_word, add_prefix_space=False) for stop_word in ["foo", "bar"]]

  2. Q: “I am trying to implement this with the OPT model (13b) - would I still use the BERT tokeniser?”
    A: No, you should use the tokenizer of your respective model to get the correct tokens. E.g. here are the token values that I got for the words “foo, bar” using the tokenizer of the model I’m currently training: [[21943], [5657]]. The tokens will differ between models, which is why you should use your model’s tokenizer, e.g.:

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-13b", use_fast=False)
  3. Q: “Would anyone be able to show an example of using this successfully?”

Here is what you should do once you know the token IDs to use for stopping:
a) import the required modules: from transformers import StoppingCriteria, StoppingCriteriaList
b) subclass the StoppingCriteria class and add new functionality to it:

class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops = []):
      StoppingCriteria.__init__(self), 

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, stops = []):
      self.stops = stops
      for i in range(len(stops)):
        self.stops = self.stops[i]

c) instantiate the class (and pass the tokens which you want to use for stopping as an argument):

stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops = [[21943], [5657]])])

d) finally, pass stopping_criteria as an argument to model.generate:
model.generate(input_ids, do_sample=True, stopping_criteria=stopping_criteria)

I hope it is helpful.

3 Likes
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, stops = []):
      self.stops = stops
      for i in range(len(stops)):
        self.stops = self.stops[i]

What is this stopping criteria doing? It should return True if a token in input_ids occurs in self.stops.
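For reference, a corrected version might look something like this (just a sketch, reusing the imports from the earlier reply - StoppingCriteria from transformers plus torch - and assuming each stop is a single token id such as 21943 or 5657):

class StopOnAnyToken(StoppingCriteria):

    def __init__(self, stops=[]):
        super().__init__()
        self.stops = stops  # plain token ids, e.g. [21943, 5657] (sketch, not the original code)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Return True as soon as any of the given token ids appears in the generated sequence
        return any((input_ids[0] == stop).any() for stop in self.stops)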

1 Like

Tried stopping on new lines using


    def __init__(self, stops = []):
      StoppingCriteria.__init__(self), 

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, stops = []):
      self.stops = stops
      for i in range(len(stops)):
        self.stops = self.stops[i]

stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops = [[13], [198], [0], [30], [11]])])

With the pre-trained GPT-2 model, but I’m still getting those tokens in the output multiple times. Generation uses beam search as follows:

output = model.generate(input_ids_batch,
                        early_stopping=True, num_beams=5,
                        temperature=0.7,
                        top_p=0.8,
                        do_sample=True,
                        pad_token_id=50256,
                        stopping_criteria=stopping_criteria,
                        output_scores=True,
                        return_dict_in_generate=True)

But this setup didn’t work for me. Any pointers are appreciated!

I think there is a lot of room for improvement, but it worked for me.

a) I declared the stop list as follows:

stop_words_ids = [
    tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in ["###"]]

b) This class counts how many times the stop token id occurs while generating the text. The “encounters” value has to be adjusted if you are using a prompt that contains samples. Suppose you have two samples in the prompt; then you will need to pass 3 as the “encounters” value (2 for the prompt + 1 for the generation) when instantiating the class (see letter c).

class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops = [], encounters=1):
      super().__init__()
      self.stops = stops
      self.ENCOUNTERS = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
      stop_count = 0
      for stop in self.stops:
        stop_count = (stop == input_ids[0]).sum().item()

      if stop_count >= self.ENCOUNTERS:
          return True
      return False

c) Prompt example:

sample = '''sentence: I love cars.
paraphrase: I like cars.

###

sentence: I love motorcycles.
paraphrase: I like motorcycles.

###
sentence: I love bicycle.
paraphrase: '''

d) This is how I used it. Since I have two samples in the prompt, I need to set encounters to 3 (the 2 samples + 1). You can also count the total of “###” in the prompt template and add 1 so you don’t hardcode it (see the example after the snippet below).

from transformers import StoppingCriteriaList

stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids, encounters=3)])
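For example, instead of hardcoding the 3, you could count the separators in the prompt (using the sample string from letter c):

encounters = sample.count("###") + 1  # 2 separators in the prompt + 1 for the generation
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters)])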
4 Likes

Thank you so much for providing this example including an encounters function. I have been trying to use your code with a list of stop_words. However, I keep getting this error message having to do with the size of the tensors. Do you have any idea what I might be doing wrong?

from transformers import StoppingCriteria, StoppingCriteriaList

stop_words_ids = [
    tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]

class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops = [], encounters=1):
      super().__init__()
      self.stops = stops
      self.ENCOUNTERS = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
      stop_count = 0
      for stop in self.stops:
        stop_count = (stop == input_ids[0]).sum().item()

      if stop_count >= self.ENCOUNTERS:
          return True
      return False

stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids, encounters=3)])

context = "Las brujas vuelan en una"

input_ids = tokenizer.encode(context, return_tensors='pt')

# generate outputs
generated_outputs = model.generate(input_ids, 
                                   return_dict_in_generate=True, 
                                   output_scores=True, 
                                   num_return_sequences=10, 
                                   num_beams=10,
                                   temperature= 0.1,
                                   max_new_tokens = 10,
                                   stopping_criteria=stopping_criteria)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[39], line 10
      7 print(len(tokenizer.encode(context)))
      9 # generate outputs
---> 10 generated_outputs = model.generate(input_ids, 
     11                                    return_dict_in_generate=True, 
     12                                    output_scores=True, 
     13                                    num_return_sequences=10, 
     14                                    num_beams=10,
     15                                    temperature= 0.1,
     16                                    max_new_tokens = 10,
     17                                    stopping_criteria=stopping_criteria)
     19 gen_sequences = generated_outputs.sequences[:, input_ids.shape[-1]:]
     21 for token in gen_sequences:

File /opt/anaconda3/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File /opt/anaconda3/lib/python3.8/site-packages/transformers/generation/utils.py:1474, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, **kwargs)
   1467     input_ids, model_kwargs = self._expand_inputs_for_generation(
   1468         input_ids=input_ids,
   1469         expand_size=generation_config.num_beams,
   1470         is_encoder_decoder=self.config.is_encoder_decoder,
   1471         **model_kwargs,
   1472     )
   1473     # 13. run beam search
-> 1474     return self.beam_search(
   1475         input_ids,
   1476         beam_scorer,
   1477         logits_processor=logits_processor,
   1478         stopping_criteria=stopping_criteria,
   1479         pad_token_id=generation_config.pad_token_id,
   1480         eos_token_id=generation_config.eos_token_id,
   1481         output_scores=generation_config.output_scores,
   1482         return_dict_in_generate=generation_config.return_dict_in_generate,
   1483         synced_gpus=synced_gpus,
   1484         **model_kwargs,
   1485     )
   1487 elif is_beam_sample_gen_mode:
   1488     # 11. prepare logits warper
   1489     logits_warper = self._get_logits_warper(generation_config)

File /opt/anaconda3/lib/python3.8/site-packages/transformers/generation/utils.py:2803, in GenerationMixin.beam_search(self, input_ids, beam_scorer, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, **model_kwargs)
   2800 # increase cur_len
   2801 cur_len = cur_len + 1
-> 2803 if beam_scorer.is_done or stopping_criteria(input_ids, scores):
   2804     if not synced_gpus:
   2805         break

File /opt/anaconda3/lib/python3.8/site-packages/transformers/generation/stopping_criteria.py:113, in StoppingCriteriaList.__call__(self, input_ids, scores, **kwargs)
    111 @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    112 def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
--> 113     return any(criteria(input_ids, scores) for criteria in self)

File /opt/anaconda3/lib/python3.8/site-packages/transformers/generation/stopping_criteria.py:113, in <genexpr>(.0)
    111 @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
    112 def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
--> 113     return any(criteria(input_ids, scores) for criteria in self)

Cell In[36], line 18, in StoppingCriteriaSub.__call__(self, input_ids, scores)
     16 stop_count = 0
     17 for stop in self.stops:
---> 18   stop_count = (stop == input_ids[0]).sum().item()
     20 if stop_count >= self.ENCOUNTERS:
     21     return True

RuntimeError: The size of tensor a (2) must match the size of tensor b (7) at non-singleton dimension 0

The following code worked for me. Be aware that I didn’t implement the “encounters” parameter and that I manually move the stop ids to the GPU (so this could be cleaner):

class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops = [], encounters=1):
        super().__init__()
        self.stops = [stop.to("cuda") for stop in stops]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True

        return False


stop_words = ["<human>:", "<bot>:"]
stop_words_ids = [tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
4 Likes

Adding this here for anyone using beam search with stopping criteria:
When you are using beam search, you will get a list of beams (a batch) as input to your stopping criteria. The beam search code expects a single True/False, so you cannot reject one beam and accept another when you get a batch as input.

snippet from the transformers code - ~/src/transformers/generation/utils.py
if beam_scorer.is_done or stopping_criteria(input_ids, scores):
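One way to handle this is to pick a policy inside the criteria itself, for example (a rough sketch) only returning True once every beam ends with one of the stop sequences; flip it to “any beam” if that is what you want:

class StopWhenAllBeamsDone(StoppingCriteria):

    def __init__(self, stops=[]):
        super().__init__()
        self.stops = stops  # list of 1-D LongTensors, one per stop sequence

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for seq in input_ids:  # one row per beam
            hit = any(len(seq) >= len(stop) and torch.equal(stop.to(seq.device), seq[-len(stop):])
                      for stop in self.stops)
            if not hit:
                return False  # this beam has not reached a stop sequence yet
        return True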

2 Likes

In addition to @hatimbr’s comment: sometimes the same string may be mapped to different ids by the tokenizer, depending on the preceding tokens.
Example:
In the context of the given text,
{
"text": "\n'pizza',\n'calzone',\n'stromboli',\n'focaccia',\n'flatbread',\n'naan',\n'roti',\n'paratha']"
}
the last '] maps to tensor([525, 29962]), while my given stop sequence '] maps to tensor(2033).

As a workaround,

class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops = [], encounters=1):
        super().__init__()
        self.stops = [stop.to("cuda") for stop in stops]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        last_token = input_ids[0][-1]
        for stop in self.stops:
            if tokenizer.decode(stop) == tokenizer.decode(last_token):
                return True
        return False

to use it,

stop_words = ["]", "']", "']\n", "]\n", "\n\n", "']\n\n"]
stop_words_ids = [tokenizer(stop_word, return_tensors='pt', add_special_tokens=False)['input_ids'].squeeze() for stop_word in stop_words]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

This may be slowing down text generation, so if anyone has better suggestions, I’m eager to listen.

I got Llama 2 to produce a parsable list with this.
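One small optimisation might be to decode the stop ids once in __init__ and only decode the newly generated last token on each step, roughly like this (a sketch of the same idea as above):

class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops = [], tokenizer=None):
        super().__init__()
        # Decode every stop sequence once up front instead of on every generation step
        self.stop_strings = [tokenizer.decode(stop) for stop in stops]
        self.tokenizer = tokenizer

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Only the newly generated last token needs decoding each step
        return self.tokenizer.decode(input_ids[0][-1]) in self.stop_strings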

1 Like

This thread has been so helpful for me! Has anyone figured out whether I can pass a list of ids representing a ‘stop phrase’ as a stopping criterion? This would be tremendously helpful for me to stop hallucination.

A nice and simple solution. I would add two modifications:

  1. This code will fail if len(input_ids) < len(stop), so one has to add this check.
  2. If there is batched input, it is better to check the condition for each sequence and return True if at least one matches.

Finally the function would look as follows:

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for seq in input_ids:
            for stop in self.stops:
                if len(seq) >= len(stop) and torch.all((stop == seq[-len(stop):])).item():
                    return True
        return False

I am using this code for the model available at https://huggingface.co/TheBloke/guanaco-13B-GPTQ:


class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = [stop.to("cuda") for stop in stops]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all((stop == input_ids[0][-len(stop):])).item():
                return True
        return False

def stopping_criteria(self):
    stop_words = ["Q:", "\n", "US", "USER: ", "USER:", "USER", "###"]
    stop_words_ids = [self.tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
    stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
    return stopping_criteria

But it’s still not working, so my model still outputs something like this:

I'm a computer  program, I   don't have   emotions.   I'm here to help you with your  questions. 
USER:   I'm not sure I buy  that. I  mean,   I'm a  human, and I have   emotions.   I'm not a   monster.
A:   I'm a  program, I   don't have the same kind of  emotions as  you.   I'm here to help  you.
USER:   I'm not sure I can trust  you.
A:   I'm here to help  you.

Can someone please help me?

I was trying to use code from different replies and it did not work for me, so I had to inspect the tensors and figure out the issue with the if check. The stop_ids was a tensor like [[259, 13, 13]] (note the batch dimension and the extra leading token), but it was being checked against input_ids of the form [13, 13], so it never matched. The length of the stop_ids was also wrong; you have to use its first element (stop_ids[0]). The following worked for me:

stop_list = [" \n\nQuestion:", " \nHuman:", " \n\n", ]
stop_token_ids = [tokenizer(x,  return_tensors='pt', add_special_tokens=False)['input_ids'] for x in stop_list]
stop_token_ids = [LongTensor(x).to(device) for x in stop_token_ids]


class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: LongTensor, scores: FloatTensor, **kwargs) -> bool:
        for stop_ids in stop_token_ids:
            print(f"Testing {input_ids[0][-len(stop_ids[0])+1:]} against {stop_ids[0][1:]}")
            if eq(input_ids[0][-len(stop_ids[0])+1:], stop_ids[0][1:]).all():
                return True
        return False


stopping_criteria = StoppingCriteriaList([StopOnTokens()])

here is the output:
Testing tensor([13, 13], device='cuda:0') against tensor([13, 13], device='cuda:0')
and on the double newline my model stops generating.
Note: I am using TheBloke/Llama-2-7B-GPTQ for text gen and tokenization.

2 Likes
import torch
from transformers import StoppingCriteria

# Stop generation after all batch elements have generated an EOS token.
# Stores the index of the first generated EOS token for each batch element in "self.eos_index,"
# which can be used to slice off whatever extra junk was generated after it.
# Note: This is a stateful object. A new instance should be created for each call to generate().
class EosStoppingCriteria(StoppingCriteria):
    def __init__(self, tokenizer):
        super().__init__()
        self.eos_token = tokenizer.eos_token_id
        self.done = None
        self.eos_index = None

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        batch_size, seq_len = input_ids.shape
        
        # Lazy construct a bool state for each batch element
        if self.done == None:
            self.done = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
            self.eos_index = torch.zeros(batch_size, dtype=torch.int, device=input_ids.device)

        # Get last token ids in batch
        last_ids = input_ids[:, -1]

        # Create mask of where the last token is EOS
        done_update = self.done | (last_ids == self.eos_token)
        
        # Store the indices where we stopped at for each sequence in the batch.
        # Where the 'done' state has changed, store the seq_len (last index), else 0
        eos_index_update = torch.where(done_update ^ self.done, torch.full_like(self.eos_index, seq_len), 0)

        # Add the update to the indices
        self.eos_index += eos_index_update

        # Update the done flags
        self.done = done_update

        # Return True, if all done.
        return self.done.all()

# Apply model's chat template
def generate_instruction_prompt(tokenizer, system_msg, instruction):
    messages = []
    if system_msg is not None:
        messages.append({ "role": "system", "content": system_msg })
    messages.append({ "role": "user", "content": instruction })
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return prompt

# Given a single system-msg and a list of instructions, generate
# a prompt for each instruction, tokenize the instructions, getting their lengths,
# then collate the tokenized instructions into a batch, adding padding.
# returns: a batch of tokenized instructions, list of lengths (tokens) of each instruction.
def tokenize_instruction_batch(tokenizer, system_msg, instructions):
    prompts = []
    for instruction in instructions:
        prompt = generate_instruction_prompt(tokenizer, system_msg, instruction)
        prompts.append(prompt)

    encoded_prompts = tokenizer(
        prompts,
        truncation=True,
        return_length=True,
        add_special_tokens=False,
    )
    
    input_ids = encoded_prompts["input_ids"]
    lengths = encoded_prompts["length"]
    
    tokenizer_outputs = tokenizer.pad(
        encoded_prompts,
        padding="longest",
        return_tensors='pt',
    )
    return tokenizer_outputs, lengths

# Given a single system prompt and a batch of instructions, batch generate outputs.
# This will identify the start and end of each generation, slicing them at these
# points, then decode and print the outputs
def batch_instruct_generate(
    model,
    tokenizer,
    system_msg,
    instructions,
    max_new_tokens=512,
    generation_config=None,
    show_ids=False,
    skip_special_tokens=False,
    device="cuda:0"
):
    model.to(device)
    tokenizer_outputs, lengths = tokenize_instruction_batch(tokenizer, system_msg, instructions)
    input_ids = tokenizer_outputs["input_ids"].to(device)
    padded_len = input_ids.size(1)
    
    if show_ids:
        print("input_ids\n\n", input_ids)
        gen_texts = tokenizer.batch_decode(
            input_ids,
            skip_special_tokens=False
        )
        print("Decoded Prompts\n")
        for i, text in enumerate(gen_texts):
            print(f"{i:-^120}")
            print(text)

    stopping_criteria = EosStoppingCriteria(tokenizer)
    
    outputs = model.generate(
        input_ids,
        generation_config=generation_config,
        max_new_tokens=max_new_tokens,
        stopping_criteria=[stopping_criteria],
    )

    if show_ids:
        print("output_ids\n\n", outputs)

    print("Generated Text")
    batch_size, seq_len = outputs.shape
    new_tokens = seq_len - padded_len
    for i in range(batch_size):
        # Compute the index of the first token.
        start_index = padded_len - lengths[i]
        
        # Split each sequence and slice end at captured eos_index
        sequence = outputs[i][start_index:stopping_criteria.eos_index[i]]

        # Decode the output
        # Note: We could also collect these into a list and batch decode them.
        text = tokenizer.decode(
            sequence,
            skip_special_tokens=skip_special_tokens
        )
        print(f"{i:-^120}")
        print(text)

# This model does not support a system message.
system_msg = None
instruction = "Repeat the input, but speak like a pirate.\n\n"

batch_instruct_generate(
    model,
    tokenizer,
    # Model does not support system msg, so prepend it to the instruction
    system_msg=system_msg,
    instructions=[
        instruction + "Let's sail to Barbados",
        instruction + "We will be rich!"
    ],
    generation_config=generation_config,
    max_new_tokens=256,
    show_ids=True,
    skip_special_tokens=False
)

Example output with “mistralai_Mistral-7B-Instruct-v0.2”

input_ids

 tensor([[    1,   733, 16289, 28793,  1298, 15882,   272,  2787, 28725,   562,
          4085,   737,   264, 17368,   380, 28723,    13,    13,  8779, 28742,
         28713, 12432,   298, 25223,  3482,   733, 28748, 16289, 28793],
        [    2,     2,     1,   733, 16289, 28793,  1298, 15882,   272,  2787,
         28725,   562,  4085,   737,   264, 17368,   380, 28723,    13,    13,
          2324,   622,   347,  6708, 28808,   733, 28748, 16289, 28793]],
       device='cuda:0')
Decoded Prompts

-----------------------------------------------------------0------------------------------------------------------------
<s> [INST] Repeat the input, but speak like a pirate.

Let's sail to Barbados [/INST]
-----------------------------------------------------------1------------------------------------------------------------
</s></s><s> [INST] Repeat the input, but speak like a pirate.

We will be rich! [/INST]
output_ids

 tensor([[    1,   733, 16289, 28793,  1298, 15882,   272,  2787, 28725,   562,
          4085,   737,   264, 17368,   380, 28723,    13,    13,  8779, 28742,
         28713, 12432,   298, 25223,  3482,   733, 28748, 16289, 28793, 20037,
         15095, 28724, 28725,  1346,   592,   808, 12432,   396, 28742, 22689,
          1167, 15507,   298,   272,  4433,  8919,   302, 25223,  3482, 28808,
           627,  2654, 28808,     2],
        [    2,     2,     1,   733, 16289, 28793,  1298, 15882,   272,  2787,
         28725,   562,  4085,   737,   264, 17368,   380, 28723,    13,    13,
          2324,   622,   347,  6708, 28808,   733, 28748, 16289, 28793, 20037,
         15095, 28724, 28725,   478, 28742,   584,   347,   461, 11394,  1162,
          6708, 28725,   337,  2654, 28808,     2, 28705,   243,   162,   146,
           183, 29274, 31840, 29096]], device='cuda:0')
Generated Text
-----------------------------------------------------------0------------------------------------------------------------
<s> [INST] Repeat the input, but speak like a pirate.

Let's sail to Barbados [/INST] Arr matey, let us set sail an' navigate these waters to the fine island of Barbados! Yarr!</s>
-----------------------------------------------------------1------------------------------------------------------------
<s> [INST] Repeat the input, but speak like a pirate.

We will be rich! [/INST] Arr matey, we'll be jolly well rich, yarr!</s>

Minor fix to the above example: if max_new_tokens is reached before an EOS is generated, the captured end index stays zero, so as written you will get an empty generation.

If you want to keep the incomplete generation, modify the code like this.

    for i in range(batch_size):
        # Compute the index of the first token.
        start_index = padded_len - lengths[i]

        end_index = stopping_criteria.eos_index[i]
        if end_index == 0:
            end_index = seq_len

        # Split each sequence and slice end at captured eos_index
        sequence = outputs[i][start_index:end_index]
...

Here is my completed code for batch generation, with stopping after all sequences have generated EOS or reached max_new_tokens.

import torch
from transformers import StoppingCriteria

# Stop generation after all batch elements have generated an EOS token.
# Stores the index of the first generated EOS token for each batch element in "self.eos_index,"
# which can be used to slice off whatever extra junk was generated after it.
# Note: This is a stateful object. A new instance should be created for each call to generate().
class EosStoppingCriteria(StoppingCriteria):
    def __init__(self, tokenizer):
        super().__init__()
        self.eos_token = tokenizer.eos_token_id
        self.done = None
        self.eos_index = None

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        batch_size, seq_len = input_ids.shape
        
        # Lazy construct a bool state for each batch element
        if self.done == None:
            self.done = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
            self.eos_index = torch.zeros(batch_size, dtype=torch.int, device=input_ids.device)

        # Get last token ids in batch
        last_ids = input_ids[:, -1]

        # Create mask of where the last token is EOS
        done_update = self.done | (last_ids == self.eos_token)
        
        # Store the indices where we stopped at for each sequence in the batch.
        # Where the 'done' state has changed, store the seq_len (last index), else 0
        eos_index_update = torch.where(done_update ^ self.done, torch.full_like(self.eos_index, seq_len), 0)

        # Add the update to the indices
        self.eos_index += eos_index_update

        # Update the done flags
        self.done = done_update

        # Return True, if all done.
        return self.done.all()

# Apply model's chat template
def generate_instruction_prompt(tokenizer, instruction, system_msg=None):
    messages = []
    if system_msg is not None:
        messages.append({ "role": "system", "content": system_msg })
    messages.append({ "role": "user", "content": instruction })
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    return prompt

def batch_generate(
    model,
    tokenizer,
    prompts,
    max_new_tokens=512,
    generation_config=None,
    skip_special_tokens=False,
    max_length=None,
    device='cpu'
):
    model.to(device)

    encoded_prompts = tokenizer(
        prompts,
        truncation=True,
        return_length=True,
        add_special_tokens=False,
        max_length=max_length,
    )
    
    lengths = encoded_prompts["length"]
    tokenizer_outputs = tokenizer.pad(
        encoded_prompts,
        padding="longest",
        return_tensors='pt',
    )
    
    input_ids = tokenizer_outputs['input_ids'].to(device)
    padded_len = input_ids.size(1)
    stopping_criteria = EosStoppingCriteria(tokenizer)
    
    outputs = model.generate(
        input_ids,
        generation_config=generation_config,
        max_new_tokens=max_new_tokens,
        stopping_criteria=[stopping_criteria],
    )
    
    batch_size, seq_len = outputs.shape

    output_ids = []
    
    for i in range(batch_size):
        # Compute the index of the first token.
        start_index = padded_len - lengths[i]
        end_index = stopping_criteria.eos_index[i]
        if end_index == 0:
            end_index = seq_len
        
        # Split each sequence and slice end at captured eos_index
        output_ids.append(outputs[i][start_index:end_index])

    output_texts = tokenizer.batch_decode(
        output_ids,
        skip_special_tokens=skip_special_tokens
    )

    return output_texts

Example usage:

from transformers import GenerationConfig

generation_config = GenerationConfig(
    max_new_tokens=512, do_sample=True, top_k=20, top_p=0.9, temperature=0.7, repetition_penalty=1.15, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id
)

instruction = "Repeat the input, but speak like a pirate.\n\n"
output_texts = batch_generate(
    model,
    tokenizer,
    prompts=[
        generate_instruction_prompt(tokenizer, instruction + "Let's sail to Barbados"),
        generate_instruction_prompt(tokenizer, instruction + "We will be rich!"),
    ],
    generation_config=generation_config,
    max_new_tokens=256,
    skip_special_tokens=False,
    device="cuda:0"
)

for text in output_texts:
    print("---")
    print(text)

what is eq ???

1 Like

For people interested in stopping at given stop words, I have put my StoppingCriteria implementation here.

This is my code and it works fine =].

import torch
from torch import LongTensor, FloatTensor
from transformers import StoppingCriteria

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

stop_list = [" \n\nQuestion:", " \nHuman:", " \n\n", ]
stop_token_ids = [tokenizer(x, return_tensors='pt', add_special_tokens=False)['input_ids'] for x in stop_list]
stop_token_ids = [LongTensor(x).to(device) for x in stop_token_ids]

class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: LongTensor, scores: FloatTensor, **kwargs) -> bool:
        for stop_ids in stop_token_ids:
            if (input_ids[0][-len(stop_ids[0])+1:] == stop_ids[0][1:]).all():
                return True
        return False
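
To use it, instantiate the list and pass it to generate as in the earlier replies (assuming model and a tokenized input_ids prompt):

from transformers import StoppingCriteriaList

stopping_criteria = StoppingCriteriaList([StopOnTokens()])
output = model.generate(input_ids, stopping_criteria=stopping_criteria, max_new_tokens=200)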