Custom Decoding Strategy


I am currently working on a project that involves controllable text generation.
To do this, I am building a ‘post-processing module’ that reranks the logits during the decoding process.

To start, I have forked transformers and added a new GenerationMode (GenerationMode.TEST_GREEDY_SEARCH for this example) to the GenerationMixin.

Within greedy_search(), I believe next_tokens_scores holds the logit scores for the candidate tokens, but I am not sure.

One problem is that I need to decode the token IDs behind the logits into human-readable words so that my post-processing module can use them. At the moment I am just importing the tokeniser I need into the function, but I do not know how to get the current tokens.
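For what it's worth, my current understanding (which may be wrong) is that inside greedy_search() the input_ids tensor accumulates the token IDs generated so far, so decoding it should give the current text. A minimal sketch of that idea, with a toy id-to-token table standing in for the real tokeniser (with a real one it would be tokenizer.decode(input_ids[0])):

```python
# Toy stand-in for the tokeniser: a hypothetical id -> token table,
# for illustration only.
id_to_token = {0: "Hello", 1: ",", 2: " world"}

def decode(ids):
    # Stand-in for tokenizer.decode(input_ids[0], skip_special_tokens=True)
    return "".join(id_to_token[i] for i in ids)

# Stand-in for the input_ids tensor that greedy_search() maintains.
input_ids = [0, 1, 2]
print(decode(input_ids))  # prints "Hello, world"
```

If that assumption about input_ids is correct, the only remaining question is how to cleanly pass the tokeniser into the decoding loop rather than importing it inside the function.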

In summary, for each of the logits I need the token and its probability.
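To make that concrete, here is a sketch of what I am after. Assuming next_tokens_scores is a (batch, vocab_size) tensor of raw logits, a softmax over the vocab dimension would turn them into probabilities, which could then be paired with decoded tokens. In the real loop this would be torch.softmax(next_tokens_scores, dim=-1) plus tokenizer.convert_ids_to_tokens(); the pure-Python version below just illustrates the mapping with a toy vocabulary:

```python
import math

# Toy vocabulary and one row of logits, standing in for the tokeniser's
# vocab and one batch row of next_tokens_scores. Illustration only.
vocab = ["the", "a", "dog"]
logits = [2.0, 1.0, -0.5]

# Softmax (subtracting the max for numerical stability) converts the
# raw logit scores into a probability distribution over the vocabulary.
m = max(logits)
exps = [math.exp(x - m) for x in logits]
total = sum(exps)
token_probs = {tok: e / total for tok, e in zip(vocab, exps)}

# Each entry is now exactly the (token, probability) pair I want to feed
# into the post-processing module, sorted by probability.
for tok, p in sorted(token_probs.items(), key=lambda kv: -kv[1]):
    print(f"{tok}: {p:.3f}")
```

This is the shape of data my reranker expects; whether next_tokens_scores is really the right tensor to take the softmax over is exactly what I am unsure about.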

Any help would be greatly appreciated.