Exclude words from GPT-2 generate()


I want to exclude some token ids of the GPT-2 vocabulary from the `generate()` function, i.e. when the model generates the next word, it should not be able to pick any word from a given list of words. How can I achieve that?

Thank you in advance.


It seems that this is supported through the `bad_words_ids` argument of `generate()`. The docs briefly describe it: use the tokenizer to get the list of token ids for each word you want to ban, then pass those lists to `generate`:

**bad_words_ids** (`List[List[int]]`, optional) – List of token ids that are not allowed to be generated. In order to get the tokens of the words that should not appear in the generated text, use `tokenizer(bad_word, add_prefix_space=True).input_ids`.
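To make the mechanism concrete, here is a minimal pure-Python sketch of what banning token ids amounts to at each decoding step. It assumes behaviour along the lines of transformers' `NoBadWordsLogitsProcessor`, but it is an illustration, not the library's code: a single-token entry is always masked, while for a multi-token entry only its last token is masked, and only when the tokens generated so far end with the rest of that entry. The helper names (`banned_next_tokens`, `mask_logits`, `greedy_step`) and the toy 5-token vocabulary are hypothetical.

```python
import math
from typing import List, Sequence, Set


def banned_next_tokens(generated: Sequence[int], bad_words_ids: List[List[int]]) -> Set[int]:
    """Hypothetical helper: token ids that must not be emitted at the next step."""
    banned = set()
    for word in bad_words_ids:
        if not word:
            continue  # ignore empty entries
        prefix, last = word[:-1], word[-1]
        # Single-token word: always banned. Multi-token word: ban its last token
        # only if the generated sequence currently ends with the word's prefix.
        if not prefix or list(generated[-len(prefix):]) == list(prefix):
            banned.add(last)
    return banned


def mask_logits(logits: List[float], banned: Set[int]) -> List[float]:
    """Set banned positions to -inf so they can never be selected."""
    return [-math.inf if i in banned else x for i, x in enumerate(logits)]


def greedy_step(logits: List[float], generated: Sequence[int],
                bad_words_ids: List[List[int]]) -> int:
    """Pick the highest-scoring token id after masking the banned ones."""
    masked = mask_logits(logits, banned_next_tokens(generated, bad_words_ids))
    return max(range(len(masked)), key=masked.__getitem__)


if __name__ == "__main__":
    logits = [0.1, 2.0, 0.5, 1.5, 0.3]  # hypothetical scores over a toy 5-token vocabulary
    bad = [[1], [2, 3]]                 # ban token 1, and the two-token word 2 -> 3
    print(greedy_step(logits, [4, 2], bad))  # 2: tokens 1 and 3 are masked
    print(greedy_step(logits, [4, 0], bad))  # 3: only token 1 is masked
```

With transformers itself you would not write any of this; you pass the lists directly, e.g. `model.generate(input_ids, bad_words_ids=bad_words_ids)`. The `add_prefix_space=True` in the docs matters because GPT-2's BPE gives a word with a leading space (as it usually appears mid-sentence) different ids from the same word without one.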

I hope this helps!

Thank you @deathcrush, it appears that this is exactly what I want.