I have used the following code to define the stopping criteria for Llama 2:
import torch
from transformers import StoppingCriteria, StoppingCriteriaList

# define custom stopping criteria object
class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        # stop when the generated sequence ends with any of the stop token sequences
        for stop_ids in stop_token_ids:
            if torch.eq(input_ids[0][-len(stop_ids):], stop_ids).all():
                return True
        return False

stopping_criteria = StoppingCriteriaList([StopOnTokens()])
This is my stop_list = ['\nHuman: ', '\n```\n', '\nAI: ', '\nQuestion'].
However, when I ask "Who is the CEO of Meta?", Llama 2 doesn't stop on any of these stop tokens.
I am also attaching the code that converts the stop strings into LongTensors:
stop_token_ids = [tokenizer(x, return_tensors='pt')['input_ids'].squeeze() for x in stop_list]
stop_token_ids = [torch.LongTensor(x).to(device) for x in stop_token_ids]
stop_token_ids
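For reference, the suffix comparison inside `__call__` can be checked in isolation with plain tensors. This is a minimal sketch; the token IDs below are made up for illustration and are not real Llama 2 vocabulary IDs:

```python
import torch

# Hypothetical stop sequence and generated batch, purely for illustration.
stop_ids = torch.LongTensor([13, 28956, 13])
generated = torch.LongTensor([[1, 15043, 13, 28956, 13]])  # batch of one sequence

# Same check as in StopOnTokens.__call__: compare the tail of the
# generated sequence against the stop sequence, element-wise.
matches = bool(torch.eq(generated[0][-len(stop_ids):], stop_ids).all())
print(matches)  # True: the sequence ends with the stop IDs
```

If this comparison behaves as expected on toy tensors, the problem likely lies in what `stop_token_ids` actually contains after tokenization, which is worth printing and inspecting.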
Any help is appreciated! Thanks!