Implementation of Stopping Criteria List

Hi!

This is my first reply ever, so please don’t judge too strictly 🙂

I was able to solve a similar task today (for GPT-2), so maybe my suggestion will help you.

So, first things first:

  1. Q: “The instructions seem to use the Bert tokeniser - to generate tokens of the stop sequence?”
    A: To stop generation, the model needs to know which token (or sequence of tokens) marks the stop. tokenizer.encode shows you which token IDs your stop word(s) map to, e.g. for “foo” and “bar”:
# add_special_tokens=False keeps special tokens (e.g. a BOS token) out of the stop IDs
stop_words_ids = [
    tokenizer.encode(stop_word, add_prefix_space=False, add_special_tokens=False)
    for stop_word in ["foo", "bar"]
]
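This is also why the result is a list of lists: a stop word can encode to more than one token, so a stop check has to compare the tail of the generated IDs with each stop sequence, not only the last token. A minimal pure-Python sketch of that comparison (the function name `ends_with_stop` and the IDs 10 and 11 are hypothetical, chosen only for illustration):

```python
from typing import List, Sequence


def ends_with_stop(generated: Sequence[int], stops: List[List[int]]) -> bool:
    """Return True if `generated` ends with any of the token-ID sequences in `stops`."""
    for stop in stops:
        # An empty stop sequence would match everything, so skip it
        if not stop:
            continue
        # Compare the last len(stop) IDs with the stop sequence
        if len(generated) >= len(stop) and list(generated[-len(stop):]) == list(stop):
            return True
    return False


# Hypothetical IDs: one single-token stop word and one two-token stop word
stops = [[21943], [10, 11]]
print(ends_with_stop([1, 2, 21943], stops))  # True
print(ends_with_stop([1, 10, 11], stops))    # True
print(ends_with_stop([1, 11, 10], stops))    # False
```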

  2. Q: “I am trying to implement this with the OPT model (13b) - would I still use the BERT tokeniser?”
    A: No, use the tokenizer that belongs to your model, otherwise the token IDs won’t match what the model generates. For example, these are the token IDs I got for the words “foo” and “bar” with the tokenizer of the model I’m currently training: [[21943], [5657]]. Other models’ tokenizers give different IDs, so load the one for your model, e.g.:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-13b", use_fast=False)
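Putting the first two answers together, a small helper can turn stop words into token-ID lists with whatever tokenizer you pass in. This is a sketch under assumptions: `build_stop_ids` and `FakeTokenizer` are hypothetical names, and `add_special_tokens=False` is used so that special tokens stay out of the stop sequences (the OPT tokenizer, for instance, adds `</s>` at the start of encoded text). To keep it runnable without downloading a model, the example uses a tiny fake tokenizer with the same `encode` call shape.

```python
from typing import Dict, List


def build_stop_ids(tokenizer, stop_words: List[str]) -> List[List[int]]:
    """Encode each stop word into the list of token IDs the model would generate."""
    stop_ids = []
    for word in stop_words:
        # add_special_tokens=False: don't let BOS/EOS tokens leak into the stop sequence
        ids = tokenizer.encode(word, add_special_tokens=False)
        if ids:  # skip words that encode to nothing
            stop_ids.append(list(ids))
    return stop_ids


class FakeTokenizer:
    """Hypothetical stand-in for a real tokenizer, for illustration only."""

    def __init__(self, vocab: Dict[str, int], bos_id: int = 2):
        self.vocab = vocab
        self.bos_id = bos_id

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        # Whitespace "tokenization" with an optional leading BOS token
        ids = [self.vocab[piece] for piece in text.split()]
        return [self.bos_id] + ids if add_special_tokens else ids


tok = FakeTokenizer({"foo": 100, "bar": 101, "baz": 102})
print(build_stop_ids(tok, ["foo", "bar baz"]))  # [[100], [101, 102]]
```

With a real model you would pass the tokenizer loaded above instead of `FakeTokenizer`, and the resulting IDs will differ from the ones quoted in this answer.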
  3. Q: “Would anyone be able to show an example of using this successfully?”

Here is what you should do once you know the token IDs to use for stopping:
a) import the required modules:
import torch
from transformers import StoppingCriteria, StoppingCriteriaList
b) subclass StoppingCriteria and implement the stop check in its __call__ method:

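class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops=None):
        super().__init__()
        # Each stop is a list of token IDs (a stop word may span several tokens)
        self.stops = [torch.tensor(stop) for stop in (stops or [])]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # Stop as soon as the sequence ends with any of the stop token sequences.
        # Only the first sequence in the batch is checked (batch size 1 assumed).
        for stop in self.stops:
            stop = stop.to(input_ids.device)
            if input_ids.shape[1] >= len(stop) and torch.equal(input_ids[0, -len(stop):], stop):
                return True
        return False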

c) instantiate the class (and pass the tokens which you want to use for stopping as an argument):

stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops = [[21943], [5657]])])

d) finally, pass stopping_criteria as an argument to model.generate:
model.generate(input_ids, do_sample=True, stopping_criteria=stopping_criteria)
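To see why `__call__` has to return `True` at the right moment, here is a simplified, pure-Python picture of the loop: after each new token, every criterion is called with all the IDs so far, and generation ends as soon as any of them says stop (or a length limit is reached). This is only an illustration, not how `model.generate` is actually implemented; `toy_generate`, `stop_on`, `count_up` and the token IDs are hypothetical.

```python
from typing import Callable, List

Criterion = Callable[[List[int]], bool]


def toy_generate(prompt: List[int], next_token: Callable[[List[int]], int],
                 criteria: List[Criterion], max_new_tokens: int = 20) -> List[int]:
    """Append tokens one at a time until any stopping criterion fires."""
    ids = list(prompt)
    for _ in range(max_new_tokens):
        ids.append(next_token(ids))  # stand-in for the model picking a token
        # Like StoppingCriteriaList: stop if ANY criterion returns True
        if any(criterion(ids) for criterion in criteria):
            break
    return ids


def stop_on(stop_ids: List[List[int]]) -> Criterion:
    """Build a criterion that fires when ids end with one of the stop sequences."""
    def criterion(ids: List[int]) -> bool:
        return any(len(ids) >= len(s) and ids[-len(s):] == s for s in stop_ids if s)
    return criterion


def count_up(ids: List[int]) -> int:
    # Hypothetical "model": always predicts the last token + 1
    return ids[-1] + 1


print(toy_generate([1], count_up, [stop_on([[4, 5]])]))  # [1, 2, 3, 4, 5]
```

In this sketch the stop tokens stay in the returned sequence, so if you only want the text before the stop word you would trim them after decoding.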

I hope it is helpful!
