Example of prefix_allowed_tokens_fn() during text generation

Probably too late but just in case it helps someone else - the following code worked for me:

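from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from genre.trie import MarisaTrie

model = AutoModelForSeq2SeqLM.from_pretrained('t5-base')
tokenizer = AutoTokenizer.from_pretrained('t5-base')

# Allowed output: the decoder start token (0) followed by the target token ids
trie = MarisaTrie([[0] + tokenizer.encode('Hello World')])

# At each step, only tokens that continue a sequence stored in the trie are allowed
output = model.generate(
    tokenizer.encode('Hello World', return_tensors='pt'),
    prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),
)
print(tokenizer.decode(output[0], skip_special_tokens=True))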

The above snippet will always produce “Hello World” as the output. You can also pass several token-id sequences when creating the MarisaTrie if you want to allow more than one output string.

The [0] is required at the start because T5 uses token id 0 (its pad token) as the decoder start token, so every generated sequence begins with 0.
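In case it helps to see what the trie lookup is doing, here is a rough, dependency-free sketch with several allowed sequences, each starting with 0. Note that SimpleTrie and make_prefix_fn are names I made up for illustration (they are not part of GENRE or transformers), and the token ids are invented rather than real T5 ids.

```python
from typing import Dict, List, Sequence


class SimpleTrie:
    """Hypothetical stand-in for genre.trie.MarisaTrie, built on nested dicts."""

    def __init__(self, sequences: Sequence[Sequence[int]]) -> None:
        self.root: Dict[int, dict] = {}
        for seq in sequences:
            node = self.root
            for token in seq:
                node = node.setdefault(token, {})

    def get(self, prefix: Sequence[int]) -> List[int]:
        """Return the token ids that may follow `prefix` (empty list if none)."""
        node = self.root
        for token in prefix:
            if token not in node:
                return []
            node = node[token]
        return list(node.keys())


def make_prefix_fn(trie: SimpleTrie):
    """Build a callback with the (batch_id, sent) signature that generate() calls."""
    def prefix_allowed_tokens_fn(batch_id, sent):
        # `sent` is a tensor inside generate(); plain lists are accepted here too
        ids = sent.tolist() if hasattr(sent, "tolist") else list(sent)
        return trie.get(ids)
    return prefix_allowed_tokens_fn


# Made-up token ids; real ones would come from tokenizer.encode(...).
# Each sequence starts with 0 (decoder start token) and ends with 1 (EOS).
allowed = [
    [0, 10, 20, 1],
    [0, 10, 30, 1],
    [0, 40, 1],
]
fn = make_prefix_fn(SimpleTrie(allowed))

print(fn(0, [0]))      # [10, 40]
print(fn(0, [0, 10]))  # [20, 30]
print(fn(0, [0, 40]))  # [1]
```

With the real tokenizer you would build the list as [[0] + tokenizer.encode(s) for s in strings] and pass it to MarisaTrie in the same way.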

The definition of the trie is taken from here: https://github.com/facebookresearch/GENRE/blob/main/genre/trie.py. That file depends on the marisa-trie package, so run 'pip install marisa-trie' in your environment first.
