Hi!
This is my first reply ever, so please don’t judge too strictly.
I was able to solve a similar task today (for GPT-2), so maybe my suggestion will help you.
So, first things first:
- Q: “The instructions seem to use the Bert tokeniser - to generate tokens of the stop sequence?”
A: in order to stop the sequence, the model has to know the token(s) that should trigger the stop. tokenizer.encode is used so that you can see which token, or sequence of tokens, corresponds to your stop word (e.g. “foo” or “bar”), e.g.:
stop_words_ids = [
    # add_special_tokens=False keeps BOS/EOS markers out of the stop ids
    tokenizer.encode(stop_word, add_prefix_space=False, add_special_tokens=False)
    for stop_word in ["foo", "bar"]
]
- Q: “I am trying to implement this with the OPT model (13b) - would I still use the BERT tokeniser?”
A: No, you should use your model’s own tokenizer to get the correct token IDs. For example, these are the IDs I got for the words “foo” and “bar” with the tokenizer of the model I’m currently training: [[21943], [5657]]. The IDs differ from model to model, which is why you need the tokenizer that matches your model, e.g.:
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-13b", use_fast=False)
- Q: “Would anyone be able to show an example of using this successfully?”
A: here is what you should do once you know the token IDs to use for stopping:
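a) import the required modules:
import torch
from transformers import StoppingCriteria, StoppingCriteriaList
b) subclass StoppingCriteria so that it returns True as soon as the generated text ends with one of the stop sequences:
class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=None):
        super().__init__()
        # each stop is a list of token ids, e.g. [[21943], [5657]]
        self.stops = [torch.tensor(stop) for stop in (stops or [])]

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop in self.stops:
            stop = stop.to(input_ids.device)
            # compare the last len(stop) tokens of the first sequence (batch size 1) with the stop
            if input_ids.shape[1] >= len(stop) and torch.equal(input_ids[0, -len(stop):], stop):
                return True
        return False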
c) instantiate the class, passing the token IDs you want to stop on as an argument:
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=[[21943], [5657]])])
d) finally, pass stopping_criteria as an argument to model.generate:
model.generate(input_ids, do_sample=True, stopping_criteria=stopping_criteria)
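If you want to check the stop-id logic without downloading a model, here is a minimal plain-Python sketch. build_stop_ids, ends_with_stop and ToyTokenizer are hypothetical names I made up, not part of transformers; ToyTokenizer only mimics the encode(text, add_special_tokens=...) call of a real Hugging Face tokenizer, and its ids are invented. The sketch encodes each stop word both bare and with a leading space (BPE tokenizers such as GPT-2’s usually give “foo” and “ foo” different ids, and in generated text a word normally follows a space), and ends_with_stop does the check a stopping criterion needs: compare the last tokens of the output with each stop sequence, which matters when a stop word is split into several tokens.

```python
from typing import List, Protocol, Sequence


class Encoder(Protocol):
    """Assumed interface: anything with a Hugging Face-style encode() method."""

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]: ...


def build_stop_ids(tokenizer: Encoder, words: Sequence[str]) -> List[List[int]]:
    """Encode each stop word bare and with a leading space, without special tokens."""
    stop_ids: List[List[int]] = []
    for word in words:
        for variant in (word, " " + word):
            ids = tokenizer.encode(variant, add_special_tokens=False)
            # skip empty encodings and duplicates
            if ids and ids not in stop_ids:
                stop_ids.append(ids)
    return stop_ids


def ends_with_stop(generated: Sequence[int], stop_ids: Sequence[Sequence[int]]) -> bool:
    """Return True if the generated ids end with any stop sequence (suffix match)."""
    for stop in stop_ids:
        if not stop:
            continue
        if len(generated) >= len(stop) and list(generated[-len(stop):]) == list(stop):
            return True
    return False


class ToyTokenizer:
    """Hypothetical stand-in with invented ids; real code would use AutoTokenizer."""

    vocab = {"foo": [1], " foo": [2], "bar": [3], " bar": [4, 5]}

    def encode(self, text: str, add_special_tokens: bool = True) -> List[int]:
        ids = self.vocab[text]
        # pretend id 0 is a BOS token that a real tokenizer might prepend
        return [0] + ids if add_special_tokens else list(ids)


stops = build_stop_ids(ToyTokenizer(), ["foo", "bar"])
print(stops)                                 # [[1], [2], [3], [4, 5]]
print(ends_with_stop([7, 8, 4, 5], stops))   # True: ends with the two-token stop [4, 5]
print(ends_with_stop([7, 8, 5, 4], stops))   # False
```

With a real model you would pass its tokenizer (e.g. AutoTokenizer.from_pretrained("facebook/opt-13b", use_fast=False)) instead of ToyTokenizer and hand the resulting list to StoppingCriteriaSub(stops=...).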
I hope this helps!