Dynamic decoder token masking

I am currently trying to train a sequence-to-sequence model with constraints on the decoder only. For each training instance, I have a list of tokens that must not appear in the decoder output, and this list is different for each input sequence.

Essentially, I want to restrict the decoder vocabulary to a specific subset, where the subset differs for each input document.

I saw that in the `GenerationConfig`, you can pass a list of `suppress_tokens`, which sets the logits of the specified tokens to -inf. But this does not seem appropriate here, since it is a single static list, while my list varies a lot from one document to another.
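To make concrete what I am after: I would need the same -inf masking that `suppress_tokens` performs, but applied with a different banned list per batch row. A minimal NumPy sketch of that per-row masking (`mask_logits` and the toy shapes are my own illustration, not an existing transformers API):

```python
import numpy as np

def mask_logits(logits: np.ndarray, banned_ids: list[list[int]]) -> np.ndarray:
    """Set the logits of banned token ids to -inf, per batch row,
    so those tokens can never win argmax / be sampled.

    logits:     (batch_size, vocab_size) array of decoder logits
    banned_ids: one list of banned token ids per batch row
    """
    masked = logits.copy()
    for row, ids in enumerate(banned_ids):
        masked[row, ids] = -np.inf
    return masked

# Toy example: vocab of 5, batch of 2, a different banned set per row.
logits = np.zeros((2, 5))
banned = [[1, 3], [0]]
out = mask_logits(logits, banned)
```

After this, `np.argmax(out, axis=-1)` can never select a banned id for its row, which is the behaviour I want both at training time (before the loss) and at generation time.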

Any idea on how to do this please? Is this even possible?