Thank you so much for providing this example including an encounters function. I have been trying to use your code with a list of stop_words
. However, I keep getting this error message having to do with the size of the tensors. Do you have any idea what I might be doing wrong?
from transformers import StoppingCriteria, StoppingCriteriaList
stop_words_ids = [
tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
class StoppingCriteriaSub(StoppingCriteria):
def __init__(self, stops = [], encounters=1):
super().__init__()
self.stops = stops
self.ENCOUNTERS = encounters
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
stop_count = 0
for stop in self.stops:
stop_count = (stop == input_ids[0]).sum().item()
if stop_count >= self.ENCOUNTERS:
return True
return False
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids, encounters=3)])
context = "Las brujas vuelan en una"
input_ids = tokenizer.encode(context, return_tensors='pt')
# generate outputs
generated_outputs = model.generate(input_ids,
return_dict_in_generate=True,
output_scores=True,
num_return_sequences=10,
num_beams=10,
temperature= 0.1,
max_new_tokens = 10,
stopping_criteria=stopping_criteria)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[39], line 10
7 print(len(tokenizer.encode(context)))
9 # generate outputs
---> 10 generated_outputs = model.generate(input_ids,
11 return_dict_in_generate=True,
12 output_scores=True,
13 num_return_sequences=10,
14 num_beams=10,
15 temperature= 0.1,
16 max_new_tokens = 10,
17 stopping_criteria=stopping_criteria)
19 gen_sequences = generated_outputs.sequences[:, input_ids.shape[-1]:]
21 for token in gen_sequences:
File /opt/anaconda3/lib/python3.8/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
24 @functools.wraps(func)
25 def decorate_context(*args, **kwargs):
26 with self.clone():
---> 27 return func(*args, **kwargs)
File /opt/anaconda3/lib/python3.8/site-packages/transformers/generation/utils.py:1474, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, **kwargs)
1467 input_ids, model_kwargs = self._expand_inputs_for_generation(
1468 input_ids=input_ids,
1469 expand_size=generation_config.num_beams,
1470 is_encoder_decoder=self.config.is_encoder_decoder,
1471 **model_kwargs,
1472 )
1473 # 13. run beam search
-> 1474 return self.beam_search(
1475 input_ids,
1476 beam_scorer,
1477 logits_processor=logits_processor,
1478 stopping_criteria=stopping_criteria,
1479 pad_token_id=generation_config.pad_token_id,
1480 eos_token_id=generation_config.eos_token_id,
1481 output_scores=generation_config.output_scores,
1482 return_dict_in_generate=generation_config.return_dict_in_generate,
1483 synced_gpus=synced_gpus,
1484 **model_kwargs,
1485 )
1487 elif is_beam_sample_gen_mode:
1488 # 11. prepare logits warper
1489 logits_warper = self._get_logits_warper(generation_config)
File /opt/anaconda3/lib/python3.8/site-packages/transformers/generation/utils.py:2803, in GenerationMixin.beam_search(self, input_ids, beam_scorer, logits_processor, stopping_criteria, max_length, pad_token_id, eos_token_id, output_attentions, output_hidden_states, output_scores, return_dict_in_generate, synced_gpus, **model_kwargs)
2800 # increase cur_len
2801 cur_len = cur_len + 1
-> 2803 if beam_scorer.is_done or stopping_criteria(input_ids, scores):
2804 if not synced_gpus:
2805 break
File /opt/anaconda3/lib/python3.8/site-packages/transformers/generation/stopping_criteria.py:113, in StoppingCriteriaList.__call__(self, input_ids, scores, **kwargs)
111 @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
112 def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
--> 113 return any(criteria(input_ids, scores) for criteria in self)
File /opt/anaconda3/lib/python3.8/site-packages/transformers/generation/stopping_criteria.py:113, in <genexpr>(.0)
111 @add_start_docstrings(STOPPING_CRITERIA_INPUTS_DOCSTRING)
112 def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
--> 113 return any(criteria(input_ids, scores) for criteria in self)
Cell In[36], line 18, in StoppingCriteriaSub.__call__(self, input_ids, scores)
16 stop_count = 0
17 for stop in self.stops:
---> 18 stop_count = (stop == input_ids[0]).sum().item()
20 if stop_count >= self.ENCOUNTERS:
21 return True
RuntimeError: The size of tensor a (2) must match the size of tensor b (7) at non-singleton dimension 0