I am currently doing research on auto-regressive text generation, and I would like to add a custom decoding strategy in the GenerationMixin class which can be called by:
model.my_decoding_strategy(...)
Any pointers (e.g., how to subclass GenerationMixin) would be much appreciated!
First things first: you can implement a stand-alone decoding strategy and it will work (e.g. you can easily implement a stand-alone greedy_search). Subclassing GenerationMixin is not a hard requirement.
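As a rough illustration of the stand-alone route, here is a minimal greedy decoding loop. It is a sketch under stated assumptions, not the transformers implementation. So that it runs without downloading a model, it uses plain Python lists and a hypothetical `next_token_scores` callable (ids so far in, one score per vocabulary entry out). With a real transformers causal LM you would instead call the model on a tensor of ids and take `outputs.logits[:, -1, :]` as the scores for the next position.

```python
from typing import Callable, List, Optional, Sequence

# Hypothetical model interface (an assumption for this sketch): maps the token
# ids generated so far to a list of next-token scores, one per vocabulary entry.
NextTokenScores = Callable[[Sequence[int]], List[float]]


def greedy_search(
    next_token_scores: NextTokenScores,
    input_ids: List[int],
    max_new_tokens: int,
    eos_token_id: Optional[int] = None,
) -> List[int]:
    """Stand-alone greedy decoding: always append the highest-scoring token."""
    output = list(input_ids)  # copy so the caller's list is not mutated
    for _ in range(max_new_tokens):
        scores = next_token_scores(output)
        # argmax over the vocabulary; max() keeps the first maximum, so ties
        # go to the lowest token id
        next_token = max(range(len(scores)), key=scores.__getitem__)
        output.append(next_token)
        if eos_token_id is not None and next_token == eos_token_id:
            break  # stop early once the end-of-sequence token is produced
    return output


def toy_model(ids: Sequence[int]) -> List[float]:
    """Fake 5-token 'model' that always prefers (last token + 1) mod 5."""
    scores = [0.0] * 5
    scores[(ids[-1] + 1) % 5] = 1.0
    return scores


print(greedy_search(toy_model, [0], max_new_tokens=6, eos_token_id=3))  # [0, 1, 2, 3]
```

Any other strategy (sampling, a custom search, etc.) can be written the same way; only the line that picks `next_token` changes.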
The main advantage of subclassing is the tight integration with .generate(), which does a lot of input preparation and checking under the hood before dispatching to the individual decoding strategies.
My suggestion:
Add a set of conditions in .generate() that trigger your decoding strategy (see here for examples).
Copy-paste the existing greedy_search (or whichever decoding strategy is most similar to the one you’re trying to implement) and adapt it from there.
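The two steps above can be sketched with a heavily simplified, hypothetical stand-in for GenerationMixin. It only mirrors the pattern: a shared generate() that validates inputs and then picks a strategy based on its arguments (in transformers, that choice depends on generation settings such as num_beams and do_sample, and the real generate() also handles generation configs, logits processors, stopping criteria and more). Everything here (`ToyGenerationMixin`, `next_token_scores`, the `no_repeat` flag, the "ban repeated tokens" tweak) is made up for illustration and is not the library's API; internal method names also differ between transformers versions. In the real library, model classes that support .generate() already include GenerationMixin, so you would subclass the concrete model class instead.

```python
from typing import List, Optional, Sequence


class ToyGenerationMixin:
    """Hypothetical, simplified stand-in for GenerationMixin (dispatch pattern only)."""

    def next_token_scores(self, ids: Sequence[int]) -> List[float]:
        raise NotImplementedError  # supplied by the concrete model class

    def generate(self, input_ids: List[int], max_new_tokens: int = 20,
                 eos_token_id: Optional[int] = None, no_repeat: bool = False) -> List[int]:
        # Shared input checks that every decoding strategy benefits from.
        if not input_ids:
            raise ValueError("input_ids must contain at least one token")
        if max_new_tokens < 0:
            raise ValueError("max_new_tokens must be non-negative")
        # Step 1: a new condition that routes to the custom strategy.
        if no_repeat:
            return self.my_decoding_strategy(list(input_ids), max_new_tokens, eos_token_id)
        return self.greedy_search(list(input_ids), max_new_tokens, eos_token_id)

    def greedy_search(self, ids: List[int], max_new_tokens: int,
                      eos_token_id: Optional[int]) -> List[int]:
        for _ in range(max_new_tokens):
            scores = self.next_token_scores(ids)
            next_token = max(range(len(scores)), key=scores.__getitem__)
            ids.append(next_token)
            if next_token == eos_token_id:
                break
        return ids

    def my_decoding_strategy(self, ids: List[int], max_new_tokens: int,
                             eos_token_id: Optional[int]) -> List[int]:
        # Step 2: copied from greedy_search, then adapted. As an arbitrary
        # example, tokens already present in the sequence are never chosen again.
        for _ in range(max_new_tokens):
            scores = self.next_token_scores(ids)
            allowed = [t for t in range(len(scores)) if t not in ids]
            if not allowed:
                break  # every token has been used; nothing left to pick
            next_token = max(allowed, key=scores.__getitem__)
            ids.append(next_token)
            if next_token == eos_token_id:
                break
        return ids


class ToyModel(ToyGenerationMixin):
    """Fake 4-token model with fixed preferences: token 1 > 2 > 3 > 0."""

    def next_token_scores(self, ids: Sequence[int]) -> List[float]:
        return [0.0, 1.0, 0.5, 0.2]


model = ToyModel()
print(model.generate([0], max_new_tokens=3))                  # [0, 1, 1, 1]
print(model.generate([0], max_new_tokens=3, no_repeat=True))  # [0, 1, 2, 3]
print(model.my_decoding_strategy([0], 5, None))               # [0, 1, 2, 3]
```

Routing through generate() keeps the shared checks in one place, while the last line shows the strategy can still be called directly, as in model.my_decoding_strategy(...).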