I’m using model.generate() for text generation.
I’m wondering whether there is a way to whitelist specific tokens so that they are returned during the beam search phase.
For example, I want to “force” the response to contain a question mark or a specific phrase.
Every token is “whitelisted” in the sense that it is considered during beam search. However, you can boost the tokens you would like to see generated by using a LogitsProcessor. Just implement your own class that raises the scores of your favourite tokens or sequence of tokens. If you are looking for a question mark, maybe you can also “blacklist” the other end-of-sentence punctuation marks in your LogitsProcessor?
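To make the boost/blacklist idea concrete, here is a minimal, dependency-free sketch. It is not the transformers class itself: BoostAndBlockProcessor, the toy vocabulary, and the bonus value are all made up for illustration, and plain Python lists stand in for tensors. It only mirrors the general shape of a processor, a callable that receives the tokens generated so far plus the next-token scores and returns adjusted scores.

```python
import math


class BoostAndBlockProcessor:
    """Hypothetical stand-in for a custom logits processor.

    Adds a bonus to the scores of tokens we want to see and sets the
    scores of unwanted tokens to -inf so they can never be selected.
    """

    def __init__(self, boost_ids, blocked_ids, bonus=5.0):
        self.boost_ids = set(boost_ids)
        self.blocked_ids = set(blocked_ids)
        self.bonus = bonus

    def __call__(self, input_ids, scores):
        # scores holds one row of next-token scores per beam/sequence.
        adjusted = []
        for row in scores:
            new_row = list(row)  # copy so the caller's scores are untouched
            for token_id in self.boost_ids:
                new_row[token_id] += self.bonus
            for token_id in self.blocked_ids:
                new_row[token_id] = -math.inf
            adjusted.append(new_row)
        return adjusted


# Toy vocabulary (an assumption): 0="hello", 1=".", 2="!", 3="?"
PERIOD, EXCLAMATION, QUESTION_MARK = 1, 2, 3

processor = BoostAndBlockProcessor(
    boost_ids=[QUESTION_MARK],
    blocked_ids=[PERIOD, EXCLAMATION],
)

scores = [[0.1, 2.0, 1.5, 0.3]]  # one beam, four vocabulary entries
adjusted = processor(input_ids=[[0]], scores=scores)

# Index of the highest-scoring token after adjustment.
best = max(range(len(adjusted[0])), key=adjusted[0].__getitem__)
```

In a real processor the same arithmetic would be applied to the score tensor, and the token ids would come from the tokenizer rather than a hand-written table.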
Thanks @100worte,
I’m new to transformers and I’m still trying to understand how to use the LogitsProcessor.
Do I just manually augment the logit scores returned?
How do I pass my custom class into model.generate?
You can pass a LogitsProcessorList to beam_search. For some concrete examples of LogitsProcessors, refer to generation_logits_process.py.
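As a rough picture of what a processor list does, the sketch below chains several score-adjusting callables and applies them in order. ProcessorChain, penalize_repeats, and boost_token are hypothetical names, and plain lists stand in for tensors. This illustrates the pattern only and is not the library's implementation.

```python
class ProcessorChain(list):
    """Hypothetical stand-in for a processor list: a list that is also callable."""

    def __call__(self, input_ids, scores):
        # Each processor sees the output of the previous one.
        for processor in self:
            scores = processor(input_ids, scores)
        return scores


def penalize_repeats(input_ids, scores, penalty=2.0):
    """Lower the score of every token that was already generated."""
    adjusted = []
    for seen, row in zip(input_ids, scores):
        new_row = list(row)
        for token_id in set(seen):
            new_row[token_id] -= penalty
        adjusted.append(new_row)
    return adjusted


def boost_token(input_ids, scores, token_id=3, bonus=5.0):
    """Raise the score of one desired token (id 3 is an arbitrary choice)."""
    adjusted = []
    for row in scores:
        new_row = list(row)
        new_row[token_id] += bonus
        adjusted.append(new_row)
    return adjusted


chain = ProcessorChain([penalize_repeats, boost_token])

# One beam that has generated token 0 twice; four vocabulary entries.
result = chain(input_ids=[[0, 0]], scores=[[1.0, 0.5, 0.2, 0.1]])
```

So yes, a custom processor just adjusts the scores it is handed and returns them. The list is only there so that several such adjustments can be combined.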
Is it possible to pass a custom LogitsProcessor to model.generate()?