Hello, I would like to use the prefix_allowed_tokens_fn as an input to the model.generate() function, in order to perform constrained text generation with BART.
length_penalty: Optional[float] = None,
no_repeat_ngram_size: Optional[int] = None,
encoder_no_repeat_ngram_size: Optional[int] = None,
num_return_sequences: Optional[int] = None,
max_time: Optional[float] = None,
max_new_tokens: Optional[int] = None,
decoder_start_token_id: Optional[int] = None,
use_cache: Optional[bool] = None,
num_beam_groups: Optional[int] = None,
diversity_penalty: Optional[float] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
output_scores: Optional[bool] = None,
return_dict_in_generate: Optional[bool] = None,
forced_bos_token_id: Optional[int] = None,
forced_eos_token_id: Optional[int] = None,
remove_invalid_values: Optional[bool] = None,
synced_gpus: Optional[bool] = None,
**model_kwargs,
) -> Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, torch.LongTensor]:
I tried to adapt the function in the original repository here, but it doesn't seem to be working. Can you please tell me if there are any examples of the kinds of functions that can be given as input to this parameter? Thank you!
Probably too late but just in case it helps someone else - the following code worked for me:
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from genre.trie import MarisaTrie
model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")
tokenizer = AutoTokenizer.from_pretrained("t5-base")
# Build a trie over the allowed output sequences; the leading 0 is the first token T5 produces
trie = MarisaTrie([[0] + tokenizer.encode("Hello World")])
# At each step, only tokens that extend a sequence stored in the trie are allowed
output = model.generate(
    tokenizer.encode("Hello World", return_tensors="pt"),
    prefix_allowed_tokens_fn=lambda batch_id, sent: trie.get(sent.tolist()),
)
The above snippet will always produce "Hello World" as the output. You can also include multiple strings when creating the Marisa trie.
The [0] is required at the start because the T5 model always produces 0 as the first token (0 is the pad token id, which T5 uses as its decoder start token).
The definition of the trie is taken from here: https://github.com/facebookresearch/GENRE/blob/main/genre/trie.py and the file requires the marisa-trie package to be installed in the environment (pip install marisa-trie).
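If you would rather not depend on GENRE, the same idea can be sketched with a plain dict-based trie. This is only a rough illustration, not the GENRE implementation: SimpleTrie and make_prefix_allowed_tokens_fn are hypothetical names, and the token ids are toy values rather than real tokenizer output. The returned function has the (batch_id, sent) signature that prefix_allowed_tokens_fn expects and returns the ids allowed as the next token.

```python
from typing import Callable, Dict, List, Sequence


class SimpleTrie:
    """Minimal dict-based trie over token-id sequences (hypothetical stand-in for MarisaTrie)."""

    def __init__(self, sequences: Sequence[Sequence[int]]) -> None:
        self.root: Dict[int, dict] = {}
        for seq in sequences:
            node = self.root
            for tok in seq:
                # Create the child node on first visit, then descend into it
                node = node.setdefault(tok, {})

    def get(self, prefix: Sequence[int]) -> List[int]:
        """Return the token ids that may follow `prefix` (empty if the prefix is unknown or complete)."""
        node = self.root
        for tok in prefix:
            if tok not in node:
                return []
            node = node[tok]
        return list(node.keys())


def make_prefix_allowed_tokens_fn(trie: SimpleTrie) -> Callable[[int, object], List[int]]:
    """Wrap a trie in a function with the (batch_id, sent) signature."""

    def prefix_allowed_tokens_fn(batch_id: int, sent) -> List[int]:
        # `sent` is normally a 1-D tensor of the ids generated so far;
        # plain lists are accepted too so the sketch runs without torch.
        prefix = sent.tolist() if hasattr(sent, "tolist") else list(sent)
        return trie.get(prefix)

    return prefix_allowed_tokens_fn


# Toy ids (assumption): 0 = decoder start token, 1 = end of sequence
trie = SimpleTrie([[0, 5, 6, 1], [0, 5, 7, 1]])
fn = make_prefix_allowed_tokens_fn(trie)
print(fn(0, [0, 5]))  # [6, 7]
```

With a real model you would build the trie from [0] + tokenizer.encode(text) for each allowed string and pass fn as prefix_allowed_tokens_fn to model.generate(). Note that this sketch returns an empty list once a sequence is complete; depending on the library version you may want to return the end-of-sequence id there instead.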
Thank you so much for this answer, it is exactly what I've been searching for!