Probably too late, but in case it helps someone else, the following code worked for me:
from transformers import AutoConfig, AutoModelForSeq2SeqLM, AutoTokenizer
from genre.trie import MarisaTrie
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
tokenizer = AutoTokenizer.from_pretrained("t5-base")
# Each allowed output is [decoder start token] + its token ids; T5's start token is 0
trie = MarisaTrie([[0] + tokenizer.encode("Hello World")])
# At every step, only the tokens the trie allows after the current prefix can be generated
output = model.generate(
    tokenizer.encode("Hello World", return_tensors="pt"),
    prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),
)
The above snippet will always produce "Hello World" as the output, because at each step the trie only permits tokens that continue that sequence. You can also include multiple strings when creating the Marisa trie.
The [0] is required at the start because the T5 model always produces 0 as the first token (its decoder start token, which is the pad token).
The trie definition is taken from https://github.com/facebookresearch/GENRE/blob/main/genre/trie.py, and that file requires the marisa-trie package to be installed in the environment (pip install marisa-trie).
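To illustrate the multiple-strings case and why the leading 0 matters, here's a rough, dependency-free sketch. `PrefixTrie` is a hypothetical dict-based stand-in for GENRE's MarisaTrie that only mimics the `get(prefix)` call used above, and the token ids are made up (101/102/201 stand in for real T5 ids; 1 is assumed to be EOS). `constrained_greedy` is a toy greedy loop with a fake scoring function, not what `model.generate` does internally.

```python
from typing import Callable, Dict, List, Sequence


class PrefixTrie:
    """Hypothetical dict-based trie over token-id sequences (illustration only)."""

    def __init__(self, sequences: Sequence[Sequence[int]] = ()):
        self.root: Dict[int, dict] = {}
        for seq in sequences:
            self.add(seq)

    def add(self, seq: Sequence[int]) -> None:
        node = self.root
        for tok in seq:
            node = node.setdefault(tok, {})

    def get(self, prefix: Sequence[int]) -> List[int]:
        """Return the token ids allowed right after `prefix` ([] if none)."""
        node = self.root
        for tok in prefix:
            if tok not in node:
                return []
            node = node[tok]
        return sorted(node)


DECODER_START = 0  # T5 begins decoding from token 0
# Made-up ids: [101, 102, 1] ~ "Hello World" + EOS, [101, 201, 1] ~ "Hello there" + EOS
trie = PrefixTrie([[DECODER_START, 101, 102, 1],
                   [DECODER_START, 101, 201, 1]])


def constrained_greedy(score: Callable[[int], float], max_len: int = 10) -> List[int]:
    """Toy greedy decoding: at each step pick the best-scoring token the trie allows."""
    out = [DECODER_START]
    while len(out) < max_len:
        allowed = trie.get(out)  # same role as prefix_allowed_tokens_fn
        if not allowed:  # end of a stored sequence
            break
        out.append(max(allowed, key=score))
    return out
```

Whatever the scores are, the result is always one of the stored sequences. If the trie were built without the leading 0, `trie.get([0])` would return nothing, so no token would be allowed at the first step.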