Using Hugging Face generate() with a custom model

Hi there, I’m wondering if there is an elegant way to use Hugging Face’s generate() function with a customized model.

Currently, I’m using the MBartForConditionalGeneration class, and I want to change how the inputs are processed inside the forward function, e.g. by adding things such as variational inference from a prior distribution. I intend to subclass MBartForConditionalGeneration and add extra properties and methods to accomplish this. Is this the right way, or is there a better approach?

Thanks in advance!

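If your custom class ultimately just changes how the logits are computed, then I think what you’ve said is the right approach. You’ll just need to make sure the forward method returns its outputs in the same format as the base model (for MBartForConditionalGeneration, a Seq2SeqLMOutput with a logits field), since generate() reads the logits from that output.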
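To make the subclassing idea concrete, here is a minimal sketch, not a definitive implementation. The class name NoisyLogitsMBart and the noise_std attribute are hypothetical, and the Gaussian noise on the logits is only a stand-in for whatever custom computation (e.g. a variational step) you actually need. The tiny random-weight config is there just so the example runs without downloading a checkpoint. functools.wraps is used because, as far as I know, generate() inspects the signature of forward in some versions, so the override keeps the parent's signature visible.

```python
import functools

import torch
from transformers import MBartConfig, MBartForConditionalGeneration


class NoisyLogitsMBart(MBartForConditionalGeneration):
    """Hypothetical subclass: changes the logits, keeps the output format."""

    def __init__(self, config):
        super().__init__(config)
        # Hypothetical extra property; set it after construction.
        self.noise_std = 0.0

    # Copy the parent's signature so code that inspects forward() still
    # sees the usual argument names (input_ids, attention_mask, ...).
    @functools.wraps(MBartForConditionalGeneration.forward)
    def forward(self, *args, **kwargs):
        # Always ask for a ModelOutput so .logits is available below.
        kwargs["return_dict"] = True
        outputs = super().forward(*args, **kwargs)
        if self.noise_std > 0:
            # Placeholder for the custom logic. Note that a loss computed
            # from labels inside super().forward() does not see this change.
            noise = torch.randn_like(outputs.logits) * self.noise_std
            outputs.logits = outputs.logits + noise
        return outputs  # still a Seq2SeqLMOutput


# Tiny random-weight model, only for checking that generate() runs.
config = MBartConfig(
    vocab_size=64,
    d_model=16,
    encoder_layers=1,
    decoder_layers=1,
    encoder_attention_heads=2,
    decoder_attention_heads=2,
    encoder_ffn_dim=32,
    decoder_ffn_dim=32,
    max_position_embeddings=64,
)
model = NoisyLogitsMBart(config).eval()
model.noise_std = 0.1

input_ids = torch.tensor([[5, 6, 7, 2]])
generated = model.generate(
    input_ids=input_ids,
    attention_mask=torch.ones_like(input_ids),
    max_new_tokens=5,
    do_sample=False,
    decoder_start_token_id=2,
)
print(generated.shape)
```

With a real checkpoint you would load the subclass with NoisyLogitsMBart.from_pretrained(...) instead of building it from a config, then set the extra attribute on the loaded model.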

On the other hand, if you need a custom sampling strategy, you’ll likely also need to override and modify the sample method or create a custom LogitsWarper.
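For the second case, a rough sketch of a custom processor could look like the following. It subclasses LogitsProcessor, which I believe shares the same call signature as LogitsWarper (input_ids and scores in, scores out) and is accepted by generate() through the logits_processor argument. The class name ScaledTemperature is hypothetical, and plain temperature scaling is used only to show the interface; the library already ships its own temperature warper.

```python
import torch
from transformers import LogitsProcessor, LogitsProcessorList


class ScaledTemperature(LogitsProcessor):
    """Hypothetical processor: divides next-token scores by a temperature."""

    def __init__(self, temperature):
        if temperature <= 0:
            raise ValueError("temperature must be positive")
        self.temperature = temperature

    def __call__(self, input_ids, scores):
        # scores has shape (batch_size, vocab_size); return a new tensor
        # of the same shape without modifying the input in place.
        return scores / self.temperature


# A list of processors is applied in order to the next-token scores.
processors = LogitsProcessorList([ScaledTemperature(2.0)])

dummy_input_ids = torch.zeros((1, 1), dtype=torch.long)
dummy_scores = torch.tensor([[2.0, 4.0]])
scaled = processors(dummy_input_ids, dummy_scores)
```

The list would then be passed to the model as model.generate(..., do_sample=True, logits_processor=processors), so no generation method has to be overridden for this kind of change.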
