I’m trying to write a `StoppingCriteria` that checks the most probable next-token prediction and stops generation if that token is not in a set of prioritized token ids, but I’m seeing that `scores` is always `None`. Is there a reason for this, or am I doing something wrong? I wrote my stopping criteria following examples from the original source:
```python
import torch
from transformers import StoppingCriteria


class TopPredictionOutsideTargetSetStoppingCriteria(StoppingCriteria):
    def __init__(self, priority_tokens_ids: list):
        self.priority_token_ids = priority_tokens_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        print("TopPred SCORES? ", scores)  # <-- always None
        # Id of the highest-scoring token for the first sequence in the batch
        top = torch.topk(scores, 1, dim=1).indices[0].item()
        if top not in self.priority_token_ids:
            return True
        return False
```
And I set it up like this:

```python
st_stopping = TopPredictionOutsideTargetSetStoppingCriteria(priority_ids)
breaking_ids = list(range(9, 21))
bt_token_stopping = InfillEndingStoppingCriteria(self.input_token_ids[0], 10, breaking_ids)
stop = StoppingCriteriaList([st_stopping, bt_token_stopping])
```
Then I just include `stopping_criteria=stop` in the `generate()` call.
Is there some way to tell it to pass `scores` to the function, assuming it isn’t passed automatically?
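A possible explanation, to be checked against the installed `transformers` version: in the generation loops I'm aware of, `scores` is only collected when `generate()` is called with both `output_scores=True` and `return_dict_in_generate=True`, and otherwise stays `None`. When it is collected, what reaches the stopping criteria appears to be the tuple of per-step score tensors gathered so far (each of shape `(batch_size, vocab_size)`), not a single tensor, so the latest step would be `scores[-1]`. The sketch below is written under those assumptions. It tolerates `None`, a tuple, or a bare tensor, and it returns one boolean per sequence, which is the shape recent releases expect from a criterion.

```python
from typing import Optional, Sequence, Union

import torch
from transformers import StoppingCriteria


class TopPredictionOutsideTargetSetStoppingCriteria(StoppingCriteria):
    """Stop a sequence when its most probable next token is not a priority token.

    Assumes generate() is called with output_scores=True and
    return_dict_in_generate=True; otherwise `scores` arrives as None.
    """

    def __init__(self, priority_token_ids: Sequence[int]):
        self.priority_token_ids = torch.tensor(list(priority_token_ids), dtype=torch.long)

    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: Optional[Union[torch.FloatTensor, Sequence[torch.FloatTensor]]],
        **kwargs,
    ) -> torch.BoolTensor:
        batch_size = input_ids.shape[0]
        if scores is None or len(scores) == 0:
            # No scores were kept, so this criterion never asks to stop.
            return torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
        # Accept either the accumulated tuple of per-step scores or a single tensor.
        last_scores = scores[-1] if isinstance(scores, (tuple, list)) else scores
        top = last_scores.argmax(dim=-1)  # shape: (batch_size,)
        ids = self.priority_token_ids.to(top.device)
        # True for every sequence whose top token falls outside the priority set.
        return ~torch.isin(top, ids)
```

Used together with the existing `StoppingCriteriaList`, the call would look like `model.generate(..., stopping_criteria=stop, output_scores=True, return_dict_in_generate=True)`. With `return_dict_in_generate=True`, `generate()` returns an output object rather than a plain tensor, so the generated ids would then be read from its `.sequences` attribute.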
Thanks in advance.