Stopping criteria for batch

I’ve successfully implemented and used my own StoppingCriteria.

However, I came across one limitation: the stopping criteria is applied to the whole batch.
That is, I can't mark specific samples of the batch as "stopped" and others as "not stopped"; I have to return a single boolean for the whole batch.

But I want more fine-grained control over which samples are stopped, so that when this is used with beam search, I can early-terminate some beams and keep generating the others.

How can I achieve that (if it's possible)?

Currently, the __call__() function of my custom stopping criteria looks like this:

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
    # Stop only once every sample in the batch satisfies the criterion
    return all(self.is_stop(sample) for sample in input_ids)

Instead of returning a single boolean, I'd like to return a tensor of booleans, something like the sketch below.
From what I saw in the source code, this doesn't seem to be implemented yet…
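For illustration only, this is the kind of per-sample return I have in mind (hypothetical, since generate() currently expects a single bool from the criteria; the class name is made up and is_stop() is my own helper from above):

from transformers import StoppingCriteria
import torch

class PerSampleCriteria(StoppingCriteria):  # hypothetical name, reuses the is_stop() helper shown above
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> torch.BoolTensor:
        # Desired behaviour: one "stop" flag per sample instead of a single value for the whole batch
        return torch.tensor([self.is_stop(sample) for sample in input_ids], dtype=torch.bool)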


Have you managed to solve this problem?
I am facing the same problem and would also be interested in a solution 🙂

Nope, I didn't solve this problem.

My workaround is to keep generating until the whole batch matches the stopping criteria, and then post-process the generated text to fit what I wanted.

So in my case, I wanted to stop generation whenever a space is generated. I stopped the generation once ALL samples had generated a space, then post-processed them to keep only the first word (anything after the first space was discarded), roughly as in the sketch below.
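A minimal sketch of that post-processing step, assuming output_ids comes from model.generate() and prompt_len is the prompt length in tokens (those names, keep_first_word, and the "gpt2" tokenizer are just placeholders for illustration):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # whichever tokenizer matches the generating model

def keep_first_word(output_ids, prompt_len):
    # Decode only the newly generated tokens, then discard everything after the first space
    texts = tokenizer.batch_decode(output_ids[:, prompt_len:], skip_special_tokens=True)
    return [text.lstrip().split(" ", 1)[0] for text in texts]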
