What's the recommended way to apply transforms to Flax models for generation?

It’s pretty easy to transform a HuggingFace Flax model at training time: you can apply whatever JAX transformations you want (e.g. `jax.jit`, `jax.vmap`, `jax.pmap`) to the loss function that calls the model. However, if you want to use a generation utility such as `generate()`, the model’s forward pass is called internally, so there is no function of your own to wrap. What’s the recommended way to transform the model’s forward pass in that case?
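For context, the training-time pattern described above can be sketched in plain JAX. This is a minimal sketch, not the HuggingFace API: `model_apply` is a hypothetical stand-in for the Flax model’s forward pass (a HuggingFace Flax model can typically be called with explicit parameters, e.g. `model(input_ids, params=params)`). The transformations wrap the loss function rather than the model itself.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a Flax model's forward pass: a single linear layer.
def model_apply(params, x):
    return x @ params["w"] + params["b"]

# The loss function calls the model; this is the function we transform.
def loss_fn(params, x, y):
    preds = model_apply(params, x)
    return jnp.mean((preds - y) ** 2)

# Compose transformations around the loss: differentiate, then compile.
grad_fn = jax.jit(jax.value_and_grad(loss_fn))

params = {"w": jnp.ones((3, 1)), "b": jnp.zeros((1,))}
x = jnp.ones((4, 3))
y = jnp.zeros((4, 1))
loss, grads = grad_fn(params, x, y)
```

Because `loss_fn` is an ordinary Python function, further transformations such as `jax.vmap` compose around it in the same way.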