How to create a custom decoding strategy in the GenerationMixin class?

I am currently doing research on auto-regressive text generation, and I would like to add a custom decoding strategy in the GenerationMixin class which can be called by:


Any pointers (e.g., how to subclass GenerationMixin) would be much appreciated!

@patrickvonplaten @sgugger

Hi @wlchen :wave:

First things first: you can implement a stand-alone decoding strategy and it will work (e.g. you can easily implement a stand-alone greedy_search). Subclassing GenerationMixin is not a hard requirement.
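
To make the stand-alone route concrete, here is a minimal, framework-free sketch of greedy search. It is not the library's greedy_search: `next_token_logits` and `toy_logits` are hypothetical names standing in for a real forward pass (with a transformers model you would typically take the last position of `model(input_ids).logits`).

```python
from typing import Callable, List, Sequence


def greedy_search(
    next_token_logits: Callable[[Sequence[int]], Sequence[float]],
    input_ids: Sequence[int],
    max_new_tokens: int,
    eos_token_id: int,
) -> List[int]:
    """Append the highest-scoring next token until EOS or the length limit."""
    ids = list(input_ids)
    for _ in range(max_new_tokens):
        logits = next_token_logits(ids)
        # Greedy step: take the argmax over the vocabulary.
        next_id = max(range(len(logits)), key=logits.__getitem__)
        ids.append(next_id)
        if next_id == eos_token_id:
            break
    return ids


def toy_logits(ids: Sequence[int]) -> List[float]:
    """Hypothetical stand-in for a model: prefers (last token + 1) mod 5."""
    vocab_size = 5
    target = (ids[-1] + 1) % vocab_size
    return [1.0 if i == target else 0.0 for i in range(vocab_size)]


print(greedy_search(toy_logits, [0], max_new_tokens=10, eos_token_id=4))
# prints [0, 1, 2, 3, 4]
```

A custom strategy only needs to change the "greedy step" line; the loop, the stopping condition and the bookkeeping stay the same.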

The main advantage of subclassing would be the tight integration with .generate(), which does a lot of input preparation and checks under the hood. From .generate(), the individual decoding strategies are then called.

My suggestion:

  1. Add a set of conditions to trigger your decoding strategy from .generate(), see here for examples.
  2. Copy-paste the existing greedy_search (or whichever decoding strategy is the most similar to the one you’re trying to implement), and adapt from there :slight_smile:
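
The two steps above can be sketched on a toy class. This only illustrates the pattern and is not the actual transformers code: `ToyGenerationMixin`, `_dispatch`, `banned_token_search` and the `banned_token_id` argument are all hypothetical names made up for the example.

```python
from typing import List, Sequence


class ToyGenerationMixin:
    """Hypothetical miniature of a generate() -> decoding strategy dispatch."""

    def next_token_logits(self, ids: Sequence[int]) -> Sequence[float]:
        raise NotImplementedError

    def generate(self, input_ids, max_new_tokens=5, **strategy_kwargs) -> List[int]:
        # Shared input preparation and checks happen once, for every strategy.
        if not input_ids:
            raise ValueError("input_ids must not be empty")
        return self._dispatch(list(input_ids), max_new_tokens, **strategy_kwargs)

    def _dispatch(self, ids, max_new_tokens, **strategy_kwargs):
        if strategy_kwargs:
            raise ValueError(f"unknown arguments: {sorted(strategy_kwargs)}")
        return self.greedy_search(ids, max_new_tokens)

    def greedy_search(self, ids, max_new_tokens):
        for _ in range(max_new_tokens):
            logits = list(self.next_token_logits(ids))
            ids.append(max(range(len(logits)), key=logits.__getitem__))
        return ids


class CustomDecodingMixin(ToyGenerationMixin):
    def _dispatch(self, ids, max_new_tokens, banned_token_id=None, **strategy_kwargs):
        # Step 1: a condition that triggers the custom strategy.
        if banned_token_id is not None:
            return self.banned_token_search(ids, max_new_tokens, banned_token_id)
        return super()._dispatch(ids, max_new_tokens, **strategy_kwargs)

    def banned_token_search(self, ids, max_new_tokens, banned_token_id):
        # Step 2: a copy of greedy_search, adapted to mask out one token.
        for _ in range(max_new_tokens):
            logits = list(self.next_token_logits(ids))
            logits[banned_token_id] = float("-inf")
            ids.append(max(range(len(logits)), key=logits.__getitem__))
        return ids


class ToyModel(CustomDecodingMixin):
    def next_token_logits(self, ids):
        # Fixed scores: token 2 is the favourite, token 1 the runner-up.
        return [0.1, 0.5, 0.9]


model = ToyModel()
print(model.generate([0], max_new_tokens=3))                     # prints [0, 2, 2, 2]
print(model.generate([0], max_new_tokens=3, banned_token_id=2))  # prints [0, 1, 1, 1]
```

Keeping the trigger condition in one place means the custom strategy still benefits from whatever preparation and validation the shared entry point performs.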

Hi @joaogante

Thanks for the clear and detailed explanation!
I will try my implementation with your suggestion :smiley:
