Thanks for the response @guillermogabrielli. I’ve already taken a cursory look at optimum, however all the export scenarios seem pretty targeted. I could not find a way to export an LM + generalized decoder.
I want to export both the model and the generation strategy. The model architecture is similar to GPT with slight modifications, and the generation strategy, including greedy and beam search, requires PyTorch to capture `if` and `for` control flow.
I very much doubt that transformers `generate()` will ever be jit-scriptable. Therefore, the solution would probably be to rewrite `generate()`, greedy search, and beam search so that they are `torch.jit.script`-able. The other option is to use the ONNX Runtime GreedySearch and BeamSearch ops, but those are usable only with CPUExecutionProvider and CUDAExecutionProvider (useless if you want to use TensorRT, for example).
You could `torch.jit.trace` the model itself, and then `torch.jit.script` the generate method.
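A minimal sketch of this split, using a hypothetical stand-in for the GPT-like model (`TinyLM` and `GreedyGenerator` below are illustrative names, not transformers APIs): the single forward pass has no data-dependent control flow, so it can be traced, while the generation loop with early stopping is scripted around it.

```python
import torch

class TinyLM(torch.nn.Module):
    # hypothetical stand-in for the real GPT-like model
    def __init__(self, vocab: int = 100, dim: int = 16):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab, dim)
        self.head = torch.nn.Linear(dim, vocab)

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # returns (batch, seq, vocab) logits
        return self.head(self.emb(ids))

model = TinyLM().eval()
# trace the model itself: one forward pass, no data-dependent control flow
traced_lm = torch.jit.trace(model, torch.randint(0, 100, (1, 4)))

class GreedyGenerator(torch.nn.Module):
    # the loop count and early stop depend on the generated tokens,
    # so this wrapper must be scripted, not traced
    def __init__(self, lm, max_new_tokens: int = 8, eos_id: int = 0):
        super().__init__()
        self.lm = lm
        self.max_new_tokens = max_new_tokens
        self.eos_id = eos_id

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        for _ in range(self.max_new_tokens):
            logits = self.lm(ids)
            next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            ids = torch.cat([ids, next_id], dim=1)
            if bool((next_id == self.eos_id).all()):
                break  # early stopping on EOS
        return ids

gen = torch.jit.script(GreedyGenerator(traced_lm))
out = gen(torch.randint(1, 100, (1, 4)))
```

The scripted module keeps the dynamic loop, while the traced submodule stays a fixed graph that downstream exporters can consume.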
For generation, I think you will not be able to obtain a meaningful ScriptModule using `torch.jit.trace`. This is because the number of generation steps and early stopping depend on which tokens are generated: the number of loop iterations is dynamic. `torch.jit.trace` would just hardcode the number of iterations from the example provided during tracing, which is not useful for other inputs.
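A small demonstration of this failure mode, using a toy data-dependent loop in place of a real generation loop: the trace replays exactly the iterations it saw during tracing, so a different input silently gets the wrong answer.

```python
import torch

def count_up(x: torch.Tensor) -> torch.Tensor:
    # data-dependent loop: iteration count depends on the input values
    while x.sum() < 10:
        x = x + 1
    return x

# tracing with torch.zeros(2) runs the loop 5 times and bakes that in
# (PyTorch emits a TracerWarning here)
traced = torch.jit.trace(count_up, torch.zeros(2))

eager = count_up(torch.full((2,), 4.0))    # stops after 1 step: tensor([5., 5.])
replay = traced(torch.full((2,), 4.0))     # runs 5 steps regardless: tensor([9., 9.])
```

The same thing happens with `generate()`: a trace made on one prompt would always run the number of decoding steps that prompt happened to need.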
If we use the `torch.jit.trace` function to convert it to TorchScript, how can we trace two different forward passes?
Normally you should just provide the model with different sets of inputs that trigger the control flows you would like to capture.
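Concretely, since a trace records only the path that actually executed, one option is to trace the module once per control-flow path, each time with inputs that trigger that path (the module and flag below are illustrative, not from the original model):

```python
import torch

class TwoPathModel(torch.nn.Module):
    def forward(self, x: torch.Tensor, use_extra: torch.Tensor) -> torch.Tensor:
        # data-dependent branch: tracing bakes in whichever path runs
        if bool(use_extra):
            return x * 2 + 1
        return x * 2

m = TwoPathModel()
# one trace per path; PyTorch warns that the branch is not captured
plain = torch.jit.trace(m, (torch.ones(2), torch.tensor(False)))
extra = torch.jit.trace(m, (torch.ones(2), torch.tensor(True)))
```

Each resulting ScriptModule ignores the flag at runtime and always replays its baked-in path, so you dispatch between the two traces yourself; if you need the branch itself inside one module, `torch.jit.script` is the way to go.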