Stopping criteria for batch

I’ve successfully implemented and used my own StoppingCriteria.

However, I came across one limitation: the stopping criteria is applied to the whole batch.
i.e. I can’t mark specific samples of the batch as “stopped” and others as “not stopped”; I have to return a single boolean for the whole batch.

But I want more fine-grained control over which samples are stopped, so that when used with beam search, I can early-terminate some beams and keep generating the other beams.

How can I achieve that (if it’s possible)?

Currently, the __call__() function of my custom stopping criteria looks like this:

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> bool:
    return all(self.is_stop(sample) for sample in input_ids)

Instead of returning a single boolean, I’d like to return a tensor of booleans.
From what I saw in the source code, it doesn’t seem to be implemented yet…
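
Something like this is what I have in mind (hypothetical, not the actual API; is_stop is my own per-sample helper):

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.BoolTensor:
    # One boolean per sequence in the batch, instead of a single value
    return torch.tensor([self.is_stop(sample) for sample in input_ids], dtype=torch.bool)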


Have you managed to solve this problem?
I am facing the same problem and would also be interested in a solution 🙂

Nope, I didn’t solve this problem.

My workaround is to keep generating until the whole batch matches the stopping criteria, and then post-process the generated text to fit what I wanted.

So in my case, I wanted to stop generation whenever a space is generated. I stopped the generation once ALL samples had generated a space, and then post-processed them to keep only the first word (anything generated after the first space was discarded).
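
For illustration, the post-processing was essentially this (a minimal sketch; outputs, prompt_len and tokenizer are stand-ins for whatever your generation setup uses):

# Keep only the first word of each generated continuation
first_words = []
for seq in outputs:
    text = tokenizer.decode(seq[prompt_len:], skip_special_tokens=True)
    first_words.append(text.split(" ", 1)[0])  # anything after the first space is discarded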


I’m running into this problem too. It’s really confusing me.

How about using a LogitsProcessor instead of a StoppingCriteria? The idea is to give the <eos> and <pad> tokens an inf logit and all other tokens a -inf logit once a sequence meets the stopping criterion. This way, the only token that can be generated after the criterion is met is <eos>, while sequences that have not yet met the criterion are unaffected.


Thanks for your reply! I tried it and it works perfectly, much nicer than my custom StoppingCriteria.

For future reference, this is what my code looks like (in my case I needed to stop generation once the model generated a space, but force the generation of at least one character at the beginning, since I’m using a character-based transformer):

import torch
from transformers import LogitsProcessor


class StopAfterSpaceIsGenerated(LogitsProcessor):
    """Logits processor (to use with HuggingFace `generate()` method:
    https://huggingface.co/docs/transformers/v4.24.0/en/main_classes/
    text_generation#transformers.generation_utils.GenerationMixin).

    This logits processor simply ensures that we generate at least one
    letter other than space, and that we don't generate anything after
    generating a space (in order to generate a single word).

    Args:
        base_len (int): Size of the given context. Used to know if this is
            the first character to generate.
        sp_token_id (int): ID of the space token.
        eos_token_id (int): ID of the EOS token.
    """

    def __init__(self, base_len: int, sp_token_id: int, eos_token_id: int):
        super().__init__()

        self.base_len = base_len
        self.sp_token_id = sp_token_id
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if input_ids.size(1) > self.base_len:
            # Build a score vector that only allows the EOS token,
            # created on the same device/dtype as `scores`
            forced_eos = torch.full(
                (scores.size(1),), -float("inf"),
                device=scores.device, dtype=scores.dtype,
            )
            forced_eos[self.eos_token_id] = 0

            # Force generation of EOS for sequences whose last token is a space
            scores[input_ids[:, -1] == self.sp_token_id] = forced_eos
        return scores

Then it can be used like this:

from transformers import LogitsProcessorList

# 35 is the ID of the space token in my character-level vocabulary
logits_processor = LogitsProcessorList(
    [StopAfterSpaceIsGenerated(base_len, 35, model.config.eos_token_id)]
)
model.generate(
    inputs["input_ids"],
    logits_processor=logits_processor,
)

@heekang how do I use inf with sampling (do_sample=True)? The approach works fine without sampling, but errors out with sampling enabled. See “Invalidate beam in do_sample mode with LogitsProcessor by setting it to -inf” for a reproducible example.