Hello team!
Is there a way to constrain the LM to decode only certain tokens at certain positions? For example, assume I need the model to decode 6 tokens, indexed 0 to 5. The tokens at positions (0, 2, 4) should come from vocabulary_1, and those at positions (1, 3, 5) from vocabulary_2.
vocabulary_1 and vocabulary_2 are disjoint sets of tokens, and both are subsets of the training vocabulary.
Any suggestion is welcome!
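I wonder whether the `prefix_allowed_tokens_fn` argument of `generate` could do this. Here is a rough, untested sketch of what I have in mind; the model name, prompt, and token ids are just placeholders:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")          # placeholder model
model = AutoModelForCausalLM.from_pretrained("gpt2")

vocabulary_1 = [10, 11, 12]   # placeholder token ids for even positions
vocabulary_2 = [20, 21, 22]   # placeholder token ids for odd positions

prompt = "Some prompt"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
prompt_len = input_ids.shape[1]

def allowed_tokens(batch_id, generated_ids):
    # Position (0-5) of the token about to be generated, relative to the prompt.
    step = generated_ids.shape[-1] - prompt_len
    return vocabulary_1 if step % 2 == 0 else vocabulary_2

outputs = model.generate(
    input_ids,
    max_new_tokens=6,
    prefix_allowed_tokens_fn=allowed_tokens,
)
```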
Update:
This can be simplified a bit for my case: vocabulary_1 is actually the first half of the training vocabulary, and vocabulary_2 is the second half. So ideally, at positions (0, 2, 4) I can mask the logits of shape (1, |V|) with logits[:, -|V|/2:] = float("-inf"), and similarly mask the other half at positions (1, 3, 5). How can this be done easily with the transformers generate function?
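To make this concrete, here is an untested sketch of the kind of custom LogitsProcessor I imagine; prompt_len is assumed to be the length of the prompt in tokens, and I'm assuming the vocab splits cleanly in half:

```python
import torch
from transformers import LogitsProcessor, LogitsProcessorList

class AlternatingHalfVocabProcessor(LogitsProcessor):
    """Mask the second half of the vocab on even steps and the first half on odd steps."""

    def __init__(self, prompt_len: int, vocab_size: int):
        self.prompt_len = prompt_len
        self.half = vocab_size // 2

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # Index (0-5) of the token being generated, relative to the prompt.
        step = input_ids.shape[-1] - self.prompt_len
        if step % 2 == 0:
            scores[:, self.half:] = float("-inf")   # keep only vocabulary_1 (first half)
        else:
            scores[:, :self.half] = float("-inf")   # keep only vocabulary_2 (second half)
        return scores

# Usage (model, tokenizer, prompt as placeholders, same as above):
# input_ids = tokenizer(prompt, return_tensors="pt").input_ids
# processors = LogitsProcessorList(
#     [AlternatingHalfVocabProcessor(input_ids.shape[1], model.config.vocab_size)]
# )
# outputs = model.generate(input_ids, max_new_tokens=6, logits_processor=processors)
```

Is passing a custom processor via logits_processor like this the intended way, or is there a simpler built-in option?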