Support for exporting generate function to ONNX?

I was wondering whether Hugging Face provides any support for exporting the generate function from the transformers library to ONNX?

Mainly, I was trying to create an ONNX model of a GPT-2-style transformer in order to speed up inference when generating replies to a conversation.

I see there's some support for exporting a single forward call to GPT-2, but not the entire for-loop used in greedy decoding, beam search, nucleus sampling, etc.
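For reference, here is a minimal sketch of the kind of single-call export I mean (my own example, not an official recipe; the output file name, opset version, and axis names are arbitrary choices):

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# torchscript=True makes the model return tuples instead of ModelOutput objects;
# use_cache=False keeps the export down to a single logits output.
model = GPT2LMHeadModel.from_pretrained("gpt2", torchscript=True, use_cache=False)
model.eval()

dummy = tokenizer("Hello, how are you?", return_tensors="pt")

torch.onnx.export(
    model,
    (dummy["input_ids"],),
    "gpt2_single_step.onnx",  # arbitrary output path
    input_names=["input_ids"],
    output_names=["logits"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "sequence"},
        "logits": {0: "batch", 1: "sequence"},
    },
    opset_version=13,
)
```

This exports one forward pass fine, but the sampling loop around it still has to live in Python.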

Take a look at Hugging Face Optimum, though note that not all models are supported. Documentation.

Thanks for the response @guillermogabrielli. I've already taken a cursory look at Optimum; however, the export scenarios all seem pretty targeted. I could not find a way to export an LM together with a generalized decoder.

Hi @nifarn , could you elaborate on what you would like to see that is not available in Optimum for encoder-decoder models?

Hi @fxmarty

I have the same need.

I want to export both the model and the generation strategy. The model architecture is similar to GPT with slight modifications, and the generation strategy, including greedy and beam search, requires PyTorch to trace if-statements and for-loops.

I found a demo that exports BART along with BeamSearch into ONNX, but it's outdated.

Any suggestion on how to achieve this in Optimum?

@nifarn Could you share your final solution?

Any help is appreciated!

Hi @in-certo, this is a very reasonable feature request indeed. The issue is tracked at Make ORTModel PyTorch free · Issue #526 · huggingface/optimum · GitHub.

I very much doubt that transformers' generate() will ever be jit-scriptable. Therefore, the solution would probably be to rewrite generate(), greedy search, and beam search so that they are jit-scriptable. The other option is to use the ONNX Runtime GreedySearch and BeamSearch ops, but those are usable only with CPUExecutionProvider and CUDAExecutionProvider (useless if you want to use TensorRT, for example).

You could torch.jit.trace the model itself, and then torch.jit.script the generate method.
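Roughly, that idea could look like the sketch below (untested, with GPT-2 standing in for your model; the wrapper names, max_new_tokens, and the greedy loop itself are just illustrative):

```python
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# torchscript=True gives tuple outputs; use_cache=False keeps the traced
# forward pass to a single logits tensor.
model = GPT2LMHeadModel.from_pretrained("gpt2", torchscript=True, use_cache=False)
model.eval()


class ForwardWrapper(torch.nn.Module):
    """Expose a single-tensor-in / single-tensor-out forward pass for tracing."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model(input_ids)[0]  # logits of shape (batch, seq, vocab)


example = tokenizer("Hello", return_tensors="pt")["input_ids"]
traced_forward = torch.jit.trace(ForwardWrapper(model), example)


class GreedyDecoder(torch.nn.Module):
    """A scriptable greedy-search loop around the traced forward pass."""

    def __init__(self, traced_forward, eos_token_id: int, max_new_tokens: int):
        super().__init__()
        self.traced_forward = traced_forward
        self.eos_token_id = eos_token_id
        self.max_new_tokens = max_new_tokens

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        for _ in range(self.max_new_tokens):
            logits = self.traced_forward(input_ids)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=-1)
            if bool((next_token == self.eos_token_id).all()):
                break
        return input_ids


decoder = torch.jit.script(GreedyDecoder(traced_forward, tokenizer.eos_token_id, 20))
print(tokenizer.decode(decoder(example)[0]))
```

The traced part captures the heavy tensor computation, while the scripted wrapper keeps the dynamic loop and early stopping.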

Hi @fxmarty

Thanks for your suggestion.

I have some questions on how to convert a generation model to TorchScript.

Take GLM as an example. The signature of its forward function contains an Optional[Tensor] variable mem. The code can be found here: modeling_glm.py · BAAI/glm-10b at main (huggingface.co).

If we use the torch.jit.trace function to convert it to TorchScript, how can we trace two different forward passes? (One takes a real mem argument, which is a List[Tensor], and the other takes None.)

For generation, I think you will not be able to obtain a meaningful ScriptModule using torch.jit.trace. This is because the number of generation steps and early stopping depend on which tokens are generated, so the number of loop iterations is dynamic. torch.jit.trace would just hardcode the number of iterations from the example provided during tracing, which is not useful for other samples.

If we use the torch.jit.trace function to convert it to TorchScript, how can we trace two different forward passes?

Normally, you should just provide the model with a different set of inputs that trigger the control-flow paths you would like to use.
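For instance, one way to do this (a rough sketch; the wrapper names and the mem keyword are placeholders to be adapted to GLM's actual forward signature) is to trace one thin wrapper per control-flow path, each with example inputs that exercise that path:

```python
from typing import List

import torch


class NoMemWrapper(torch.nn.Module):
    """Path where the underlying model is called without cached states (mem is None)."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model(input_ids)[0]


class WithMemWrapper(torch.nn.Module):
    """Path where real cached states are passed in as a list of tensors."""

    def __init__(self, model: torch.nn.Module):
        super().__init__()
        self.model = model

    def forward(self, input_ids: torch.Tensor, mem: List[torch.Tensor]) -> torch.Tensor:
        return self.model(input_ids, mem=mem)[0]


# Each wrapper is then traced separately, with example inputs that exercise its path,
# e.g. (model, input_ids, and mem here are placeholders for your actual objects):
# traced_first_step = torch.jit.trace(NoMemWrapper(model), (input_ids,))
# traced_next_steps = torch.jit.trace(WithMemWrapper(model), (input_ids, mem))
```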
