I’m trying to write a `StoppingCriteria` that checks the most probable next-token prediction and stops generation if that token is not in a set of prioritized token IDs, but I’m seeing that `scores` is always `None`. Is there a reason for this, or am I doing something wrong? I wrote my stopping criteria following examples from the original source:
```python
import torch
from transformers import StoppingCriteria


class TopPredictionOutsideTargetSetStoppingCriteria(StoppingCriteria):
    def __init__(self, priority_tokens_ids: list):
        self.priority_token_ids = priority_tokens_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        print("TopPred SCORES? ", scores)  # <-- always None
        # Index of the highest-scoring token; .item() assumes a batch size of 1
        top = torch.topk(scores, 1, dim=1).indices.item()
        # Stop when the top prediction is outside the prioritized set
        return top not in self.priority_token_ids
```
I set it up like this:

```python
from transformers import StoppingCriteriaList

st_stopping = TopPredictionOutsideTargetSetStoppingCriteria(priority_ids)
breaking_ids = list(range(9, 21))
# InfillEndingStoppingCriteria is another custom criterion (not shown here)
bt_token_stopping = InfillEndingStoppingCriteria(self.input_token_ids, 10, breaking_ids)
stop = StoppingCriteriaList([st_stopping, bt_token_stopping])
```
Then I just include `stopping_criteria=stop` in the `generate()` call.
Is there some way to tell it to pass `scores` to the function, assuming it isn’t passed automatically?
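For reference, here is a framework-free sketch of the decision the criterion is meant to make. It assumes (this is not stated above) that `scores` can arrive either as `None` or as a sequence of per-step score matrices with the most recent step last; in transformers 4.x the generation loop appears to record scores in that form only when `output_scores=True` and `return_dict_in_generate=True` are set, so verify that against the installed version. Plain lists stand in for tensors, and `should_stop` and `argmax` are hypothetical helper names.

```python
from typing import Optional, Sequence, Set

# Hypothetical types: plain nested lists stand in for torch tensors.
# One score matrix = one row of vocabulary scores per sequence in the batch.
ScoreMatrix = Sequence[Sequence[float]]


def argmax(row: Sequence[float]) -> int:
    """Index of the highest score in one row (the first index wins ties)."""
    return max(range(len(row)), key=lambda i: row[i])


def should_stop(scores: Optional[Sequence[ScoreMatrix]], priority_ids: Set[int]) -> bool:
    """Return True if any sequence's top prediction is outside priority_ids.

    Assumes `scores` is None (scores were not recorded) or a sequence of
    per-step score matrices, with the most recent step last.
    """
    if not scores:
        # No scores available: keep generating instead of crashing on None.
        return False
    latest = scores[-1]  # scores for the most recently generated token
    return any(argmax(row) not in priority_ids for row in latest)


# Two generation steps, batch size 1, vocabulary of 3 tokens.
steps = ([[0.1, 0.7, 0.2]], [[0.6, 0.3, 0.1]])
print(should_stop(steps, {1, 2}))  # → True (latest top token is 0)
print(should_stop(None, {1, 2}))   # → False
```

Stopping when *any* sequence leaves the set is one choice for batches; requiring *all* sequences would be the other. With real tensors, the same check could be written with `scores[-1].argmax(dim=-1)` and `torch.isin`.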
Thanks in advance.