I'm using model.generate() for text generation.
I'm wondering whether there is a way to whitelist specific tokens so that they are returned during the beam search phase.
For example, I want to "force" the response to contain a question mark or a specific phrase.
Every token is "whitelisted" in the sense that it is considered during beam search. However, you can boost the scores of the tokens you would like to see generated by using a LogitsProcessor. Just implement your own class that boosts your favourite tokens or sequence of tokens. If you are looking for a question mark, maybe you can also "blacklist" other end-of-sentence punctuation in your LogitsProcessor?
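To make the boost/blacklist idea concrete, here is a minimal sketch in plain Python. It only mimics the shape of a processor call (previous token ids in, adjusted scores out) and does not use transformers at all. The class name BoostAndSuppress, the toy four-token vocabulary and the boost value are all made up for illustration; a real processor would subclass LogitsProcessor and work on tensors.

```python
import math


class BoostAndSuppress:
    """Hypothetical helper (not part of transformers): boosts some token ids
    and blocks others by setting their score to -inf."""

    def __init__(self, boost_ids, suppress_ids, boost=5.0):
        self.boost_ids = set(boost_ids)
        self.suppress_ids = set(suppress_ids)
        self.boost = boost

    def __call__(self, input_ids, scores):
        # scores: one list of per-token scores for each beam.
        adjusted = []
        for row in scores:
            new_row = list(row)  # copy so the caller's scores stay untouched
            for token_id in self.boost_ids:
                new_row[token_id] += self.boost
            for token_id in self.suppress_ids:
                new_row[token_id] = -math.inf
            adjusted.append(new_row)
        return adjusted


# Toy vocabulary (assumed): 0 = ".", 1 = "!", 2 = "?", 3 = "hello"
processor = BoostAndSuppress(boost_ids=[2], suppress_ids=[0, 1])
scores = [[1.0, 0.5, 0.25, 0.125]]
adjusted = processor(input_ids=[[3]], scores=scores)
best_token = max(range(len(adjusted[0])), key=lambda i: adjusted[0][i])
```

In this toy case the question mark ends up with the highest score. Note that a boost only makes a token more likely at each step; it does not strictly guarantee that the token appears in the final output.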
Thanks @100worte,
I'm new to transformers and I'm still trying to understand how to use the LogitsProcessor.
Do I just manually augment the logit scores returned?
How do I pass my custom class into model.generate?
You can pass a LogitsProcessorList to beam_search. For some concrete examples of LogitsProcessors, refer to generation_logits_process.py.
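As a rough sketch of what that could look like, the snippet below subclasses LogitsProcessor and wraps it in a LogitsProcessorList. TokenBoostProcessor, the token id 7 and the bonus value are hypothetical, and dummy tensors stand in for a real decoding step so that no model needs to be loaded. It assumes torch and transformers are installed; in practice the token id would come from your tokenizer, for example via tokenizer.convert_tokens_to_ids("?").

```python
import torch
from transformers import LogitsProcessor, LogitsProcessorList


class TokenBoostProcessor(LogitsProcessor):
    """Hypothetical processor: adds a fixed bonus to the chosen token ids."""

    def __init__(self, token_ids, bonus):
        self.token_ids = list(token_ids)
        self.bonus = bonus

    def __call__(self, input_ids, scores):
        # scores has one row per beam and one column per vocabulary entry.
        scores = scores.clone()  # avoid modifying the caller's tensor
        scores[:, self.token_ids] += self.bonus
        return scores


# Wrap one or more processors in a list; they are applied in order.
processors = LogitsProcessorList([TokenBoostProcessor(token_ids=[7], bonus=4.0)])

# Dummy inputs (assumed shapes): 2 beams, 1 token generated so far, vocab of 10.
input_ids = torch.zeros((2, 1), dtype=torch.long)
scores = torch.zeros((2, 10))
new_scores = processors(input_ids, scores)
```

The resulting processors object is what you would hand over as the logits_processor argument when calling beam_search.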
Is it possible to pass a custom LogitsProcessor to model.generate()?