Example of prefix_allowed_tokens_fn() for constrained text generation

Hello, I would like to use the prefix_allowed_tokens_fn as an input to the model.generate() function, in order to perform constrained text generation with BART.

I tried to adapt the function in the original repository here, but it doesn’t seem to be working. Can you please tell me if there are any examples of the kinds of functions that can be given as input to this parameter? Thank you!


Probably too late but just in case it helps someone else - the following code worked for me:

from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from genre.trie import MarisaTrie

model = AutoModelForSeq2SeqLM.from_pretrained('t5-base')
tokenizer = AutoTokenizer.from_pretrained('t5-base')

# Build a trie over the allowed output sequences, prefixed with the
# decoder start token (0 for t5-base).
trie = MarisaTrie([[0] + tokenizer.encode('Hello World')])

# At each decoding step, only tokens that continue a sequence in the trie are allowed.
output = model.generate(
    tokenizer.encode('Hello World', return_tensors='pt'),
    prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),
)

The above snippet will always produce "Hello World" as the output. You can also include multiple strings when creating the Marisa trie, for example:
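A sketch of this, reusing the model and tokenizer from above (the extra candidate strings are just illustrative):

# Constrain generation to one of several candidate strings.
candidates = ['Hello World', 'Hello there', 'Goodbye World']
trie = MarisaTrie([[0] + tokenizer.encode(c) for c in candidates])

output = model.generate(
    tokenizer.encode('Hello World', return_tensors='pt'),
    prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),
)
print(tokenizer.decode(output[0], skip_special_tokens=True))

Generation is then restricted to whichever of the candidate sequences the model assigns the highest score.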

The [0] at the start is required because t5-base uses token id 0 as its decoder start token, so every sequence passed to prefix_allowed_tokens_fn begins with 0.
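Rather than hard-coding 0, you can read the start token from the model config (a small sketch; for t5-base it is 0, other seq2seq models may use a different start token or add a forced BOS token, so check the config):

# The decoder-side sequence seen by prefix_allowed_tokens_fn starts with this token.
start_id = model.config.decoder_start_token_id  # 0 for t5-base
trie = MarisaTrie([[start_id] + tokenizer.encode('Hello World')])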

The trie is defined in the GENRE repository: https://github.com/facebookresearch/GENRE/blob/main/genre/trie.py. That file requires the marisa-trie package, so run 'pip install marisa-trie' in your environment first.
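If you prefer not to install marisa-trie, a minimal pure-Python trie with the same get() interface is enough for small sets of sequences. This is only a sketch, not the GENRE implementation:

# Minimal dict-based trie exposing the same get() interface as GENRE's MarisaTrie.
class SimpleTrie:
    def __init__(self, sequences):
        self.root = {}
        for seq in sequences:
            node = self.root
            for token in seq:
                node = node.setdefault(token, {})

    def get(self, prefix):
        node = self.root
        for token in prefix:
            if token not in node:
                return []  # prefix not in the trie: no tokens allowed
            node = node[token]
        return list(node.keys())  # allowed next tokens

trie = SimpleTrie([[0] + tokenizer.encode('Hello World')])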


Thank you so much for this answer; it is exactly what I've been searching for!