Whitelist specific tokens in beam search

I’m using model.generate() for text generation.
I’m wondering whether there is a way to whitelist specific tokens so that they are returned during the beam search phase.
For example, I want to “force” the response to contain a question mark or a specific phrase.

Every token is “whitelisted” in the sense that it is considered during beam search. However, you can modify or boost the scores of the tokens you would like to see generated by using a LogitsProcessor. Just implement your own class that boosts your favourite tokens or sequence of tokens. If you are looking for a question mark, maybe you can also “blacklist” the other end-of-sentence punctuation marks in your LogitsProcessor?
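As a rough illustration of that idea, here is a minimal sketch of a custom processor. The class name TokenBoostLogitsProcessor, the bonus value and the token ids are all made up for the example; in practice you would look the ids up with your tokenizer. It only assumes that a processor is called with the current input_ids and a scores tensor of shape (batch_size * num_beams, vocab_size) and returns the modified scores.

```python
import torch
from transformers import LogitsProcessor


class TokenBoostLogitsProcessor(LogitsProcessor):
    """Hypothetical processor: adds a fixed bonus to some token ids and
    sets the scores of other token ids to -inf so they are never picked."""

    def __init__(self, boost_ids, boost=5.0, block_ids=()):
        self.boost_ids = list(boost_ids)
        self.block_ids = list(block_ids)
        self.boost = boost

    def __call__(self, input_ids, scores):
        # Work on a copy so the caller's tensor is left untouched.
        scores = scores.clone()
        if self.boost_ids:
            scores[:, self.boost_ids] += self.boost
        if self.block_ids:
            scores[:, self.block_ids] = float("-inf")
        return scores


# Toy check with made-up ids: 3 stands in for "?", 1 and 2 for "." and "!".
processor = TokenBoostLogitsProcessor(boost_ids=[3], boost=5.0, block_ids=[1, 2])
dummy_input_ids = torch.zeros((2, 4), dtype=torch.long)  # 2 beams, 4 tokens so far
dummy_scores = torch.zeros((2, 6))                       # toy vocabulary of 6 tokens
boosted = processor(dummy_input_ids, dummy_scores)
print(boosted[0].tolist())
```

Note that a bonus only makes a token more likely; it does not guarantee the token appears. Blocking the competing tokens is the stronger lever, at the cost of fluency if you block too much.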

Thanks @100worte,
I’m new to transformers and I’m still trying to understand how to use the LogitsProcessor.
Do I just manually adjust the logit scores it receives and return them?
How do I pass my custom class into model.generate?

You can pass a LogitsProcessorList to beam_search. For some concrete examples of LogitsProcessor classes, refer to generation_logits_process.py.
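A small sketch of what such a list looks like: it chains a built-in processor (MinLengthLogitsProcessor) with a custom one and applies them in order. The AddBonus class, the token ids and the bonus are assumptions made up for the example, and the call at the end uses dummy tensors instead of a real model.

```python
import torch
from transformers import (
    LogitsProcessor,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
)


class AddBonus(LogitsProcessor):
    """Hypothetical processor that adds a bonus to a single token id."""

    def __init__(self, token_id, bonus):
        self.token_id = token_id
        self.bonus = bonus

    def __call__(self, input_ids, scores):
        scores = scores.clone()
        scores[:, self.token_id] += self.bonus
        return scores


EOS_ID = 0            # made-up id for the end-of-sequence token
QUESTION_MARK_ID = 3  # made-up id for "?"

processors = LogitsProcessorList(
    [
        # Built-in: forbids EOS until the sequence has at least 10 tokens.
        MinLengthLogitsProcessor(10, EOS_ID),
        # Custom: nudges the search towards the question mark.
        AddBonus(QUESTION_MARK_ID, 5.0),
    ]
)

# The list is callable and runs each processor in order.
input_ids = torch.zeros((2, 4), dtype=torch.long)  # 2 beams, 4 tokens so far
scores = torch.zeros((2, 6))                       # toy vocabulary of 6 tokens
combined = processors(input_ids, scores)
print(combined[0].tolist())
```

As far as I know, recent transformers releases also accept a list like this through the logits_processor argument of model.generate(), so you may not need to call beam_search directly; check the documentation for your installed version.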

Is it possible to pass a custom LogitsProcessor to model.generate()?