Do you have a full reusable piece of code for that?
With newer versions of `transformers` you might not need a custom stopping criterion.
Just pass `eos_token_id` to Hugging Face's `generate` function; this is required for Llama 3:
terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]
model.generate(..., eos_token_id=terminators)
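As a self-contained sketch of how that terminators list is assembled: the real tokenizer would come from `AutoTokenizer.from_pretrained(...)`, so the stub class below just mimics the two calls used, with illustrative ids (Llama 3 uses `<|eot_id|>` as its chat turn terminator).

```python
# Stub standing in for a real Hugging Face tokenizer; ids are illustrative.
class StubTokenizer:
    eos_token_id = 128001
    _vocab = {"<|eot_id|>": 128009}

    def convert_tokens_to_ids(self, token):
        # A real tokenizer returns the unk id for unknown tokens;
        # kept minimal here for the sketch.
        return self._vocab[token]

tokenizer = StubTokenizer()

# Same pattern as above: stop on either the regular EOS token
# or Llama 3's end-of-turn token.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
print(terminators)
```

With a real tokenizer, `terminators` would then be passed as `eos_token_id=terminators` to `model.generate`.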
Thanks, that is a lot nicer than I was expecting.
One thing to note: the character you want to stop on will often be embedded inside larger tokens.
For example, here are some tokens and token ids for tokens containing "a":
[('<|pad_0|>', 32001),
('Abstract', 16384),
('Adapter', 8683),
('Alias', 20766),
('Alpha', 19510),
('Amaz', 26798),
('America', 25696),
('American', 12188),
So it may be more complete to do something like:
char = "a"
# identify all tokens containing char
terminators = [
    token_id for token, token_id in tokenizer.vocab.items()
    if char in token
]
terminators.append(tokenizer.eos_token_id)
model.generate(..., eos_token_id=terminators)
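Here is the same comprehension run against a toy vocabulary, so you can see it catching tokens that merely contain the character (the vocabulary and ids below are illustrative, not from a real tokenizer):

```python
char = "a"

# Toy stand-in for tokenizer.vocab; a real vocab has tens of
# thousands of entries. Ids are illustrative.
vocab = {"Abstract": 16384, "Adapter": 8683, "hello": 5, "<|pad_0|>": 32001}
eos_token_id = 2

# Collect every token id whose token string contains the character.
terminators = [token_id for token, token_id in vocab.items() if char in token]
terminators.append(eos_token_id)
print(sorted(terminators))  # "hello" has no "a", so id 5 is excluded
```

Note that `<|pad_0|>` is caught too, because the "a" in "pad" counts: substring matching on the vocabulary makes no distinction between special tokens and ordinary ones.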
from transformers import StoppingCriteria

class StopOnString(StoppingCriteria):
    def __init__(self, stop_string, tokenizer):
        self.stop_string = stop_string
        self.tokenizer = tokenizer
        self.stream = ""

    def reset(self):
        self.stream = ""

    def __call__(self, input_ids, scores, **kwargs):
        # Decode only the newest token and append it to the running text.
        generated = self.tokenizer.decode(input_ids[0][-1], skip_special_tokens=True)
        self.stream += generated
        if self.stream.endswith(self.stop_string):
            return True
        print(generated, end="", flush=True)
        return False
This also doubles as a streamer: each call decodes the newest token with `tokenizer.decode` and prints it as it arrives.
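To see the stopping logic in isolation, here is a sketch with the `transformers` dependency stubbed out: `StopOnStringSketch` drops the `StoppingCriteria` base class and the print side effect, and the stub tokenizer's vocabulary is purely illustrative.

```python
# Stub tokenizer: maps a token id back to text (illustrative vocabulary).
class StubTokenizer:
    _ids = {1: "Hello", 2: " world", 3: "."}

    def decode(self, token_id, skip_special_tokens=True):
        return self._ids[token_id]

# Same core logic as StopOnString above, minus the StoppingCriteria base
# class and the streaming print, so the sketch runs without transformers.
class StopOnStringSketch:
    def __init__(self, stop_string, tokenizer):
        self.stop_string = stop_string
        self.tokenizer = tokenizer
        self.stream = ""

    def __call__(self, input_ids, scores=None, **kwargs):
        # Decode only the newest token, extend the buffer, check the suffix.
        generated = self.tokenizer.decode(input_ids[0][-1], skip_special_tokens=True)
        self.stream += generated
        return self.stream.endswith(self.stop_string)

criteria = StopOnStringSketch(" world", StubTokenizer())
print(criteria([[1]], None))     # "Hello" does not end with " world"
print(criteria([[1, 2]], None))  # "Hello world" does -> stop
```

One caveat with decoding token-by-token like this: some tokenizers encode characters across multiple tokens (e.g. multi-byte sequences), so decoding each id separately can occasionally differ from decoding the whole sequence at once.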
I verified this! The change landed in January 2023, which means v4.26.0 shipped with it (confirmed):
eos_token_id (`Union[int, List[int]]`, *optional*):
Formerly it was typed as a plain `int`.