I was wondering whether Hugging Face provides any support for exporting the generate() function from the transformers library to ONNX?
Mainly, I am trying to create an ONNX model from a GPT-2 style transformer in order to speed up inference when generating replies to a conversation.
I see there’s some support for exporting a single call to GPT-2, but not the entire for loop used in greedy decoding, beam search, nucleus sampling, etc.
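For reference, the single-call export I mean looks roughly like this (a sketch using the transformers.onnx API; the exact feature name and arguments may differ across versions):

```python
# Sketch: export one forward pass of GPT-2 to ONNX with the transformers.onnx
# API. This captures a single call, not the generation loop around it.
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers.onnx import FeaturesManager, export

model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Look up the ONNX config that describes GPT-2's inputs/outputs for causal LM.
_, onnx_config_cls = FeaturesManager.check_supported_model_or_raise(
    model, feature="causal-lm"
)
onnx_config = onnx_config_cls(model.config)

export(tokenizer, model, onnx_config, onnx_config.default_onnx_opset, Path("gpt2.onnx"))
```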
Thanks for the response @guillermogabrielli. I’ve already taken a cursory look at optimum; however, all the export scenarios seem pretty targeted, and I could not find a way to export an LM together with a generalized decoder.
I want to export both the model and the generation strategy. The model architecture is similar to GPT with slight modifications, and the generation strategies, including greedy and beam search, contain if statements and for loops that PyTorch would need to capture when tracing.
I very much doubt that transformers’ generate() will ever be jit-scriptable as-is. Therefore, the solution would probably be to rewrite generate(), greedy search, and beam search so that they can be compiled with torch.jit.script. The other option is to use the ONNX Runtime GreedySearch and BeamSearch ops, but those are usable only with the CPUExecutionProvider and CUDAExecutionProvider (useless if you want to use TensorRT, for example).
You could torch.jit.trace the model itself, and then torch.jit.script the generate method.
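Something like this minimal sketch, assuming a toy tensor-in/tensor-out model (TinyLM, GreedyGenerator, max_new_tokens, and eos_id are illustrative names, not transformers APIs):

```python
import torch
import torch.nn as nn

class TinyLM(nn.Module):
    # Stand-in for the real model: maps token ids to next-token logits.
    def __init__(self, vocab_size: int = 100, hidden: int = 32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.proj = nn.Linear(hidden, vocab_size)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.proj(self.embed(input_ids))  # (batch, seq, vocab)

class GreedyGenerator(nn.Module):
    def __init__(self, model: nn.Module, max_new_tokens: int, eos_id: int):
        super().__init__()
        self.model = model
        self.max_new_tokens = max_new_tokens
        self.eos_id = eos_id

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Data-dependent loop: torch.jit.script keeps it as real control flow,
        # whereas torch.jit.trace would unroll it and freeze the length.
        for _ in range(self.max_new_tokens):
            logits = self.model(input_ids)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            if bool((next_token == self.eos_id).all()):
                break
        return input_ids

model = TinyLM().eval()
# Trace the tensor-in/tensor-out model (no control flow inside)...
traced = torch.jit.trace(model, torch.randint(0, 100, (1, 4)))
# ...then script the decoding loop around it, keeping the dynamic control flow.
generator = torch.jit.script(GreedyGenerator(traced, max_new_tokens=8, eos_id=0))
print(generator(torch.randint(0, 100, (1, 4))).shape)
```

The point of scripting the outer module is that the for loop and the early-stopping break survive as real control flow; torch.onnx.export can in principle consume the scripted module (the loop becomes an ONNX Loop node), though exporter support for scripted control flow has historically been patchy.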
If we use the torch.jit.trace function to convert it to TorchScript, how can we trace two different forward passes? (One takes a real mem variable, which is a List[Tensor], and the other takes None.)
For generation, I think you will not be able to obtain a meaningful ScriptModule using torch.jit.trace. This is because the number of generation steps and the early-stopping condition depend on which tokens are generated, so the number of loop iterations is dynamic. torch.jit.trace would just hardcode the number of iterations from the example provided during tracing, which is not useful for other inputs.
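You can see the problem with a toy function whose loop count depends on the input values:

```python
# Sketch of why tracing fails here: the loop count below depends on the data,
# but the trace freezes whatever count the example input happened to produce.
import torch

def count_up(x: torch.Tensor) -> torch.Tensor:
    while x.sum() < 10:  # data-dependent loop condition
        x = x + 1
    return x

# Tracing runs the loop 5 times for this input and emits a TracerWarning
# about converting a tensor condition to a Python bool.
traced = torch.jit.trace(count_up, torch.zeros(2))

print(traced(torch.zeros(2)))          # tensor([5., 5.]) -- matches eager
print(traced(torch.full((2,), 8.0)))   # tensor([13., 13.]) -- eager would return [8., 8.]
```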
If we use the torch.jit.trace function to convert it to TorchScript, how can we trace two different forward passes?
Normally you would just provide the model with different sets of inputs that trigger the control-flow paths you want to capture; each trace records only the path taken for its example inputs, so you end up with one traced graph per path.
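For the None-vs-List[Tensor] case specifically, one pattern is to trace each call path through its own thin wrapper, so that every traced graph sees a fixed set of tensor inputs (a sketch; MyModel, FirstStep, and NextStep are illustrative names):

```python
from typing import List, Optional

import torch
import torch.nn as nn

class MyModel(nn.Module):
    # Stand-in for the real model with an optional memory argument.
    def forward(self, x: torch.Tensor, mem: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        if mem is None:
            return x * 2.0
        return x + mem[0]

class FirstStep(nn.Module):
    # Path taken on the first call, when mem is None.
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x, None)

class NextStep(nn.Module):
    # Path taken on later calls; the list is rebuilt inside the wrapper, so
    # the traced interface only contains plain tensors.
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor, mem0: torch.Tensor) -> torch.Tensor:
        return self.model(x, [mem0])

model = MyModel().eval()
x = torch.randn(1, 4)
first_step = torch.jit.trace(FirstStep(model), (x,))
next_step = torch.jit.trace(NextStep(model), (x, torch.randn(1, 4)))
```

torch.jit.script on the original module is the alternative if you want a single callable that keeps the Optional branch itself.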