seq2seq decoding is inherently slow and using onnx is one obvious solution to speed it up. The onnxt5 package already provides one way to use onnx with T5. But if we export the complete T5 model to onnx, then we can’t use the past_key_values for decoding, since for the first decoding step past_key_values will be None and onnx doesn’t accept None input. Without past_key_values, onnx won’t give any speed-up over torch for beam search.
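To make that concrete, here’s a minimal sketch (my own illustration, not from any of the packages mentioned here) of how the decoder cache behaves during incremental decoding in plain torch:

```python
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("t5-small").eval()
tokenizer = AutoTokenizer.from_pretrained("t5-small")

enc = tokenizer("translate English to French: This is cool!", return_tensors="pt")
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

with torch.no_grad():
    # first decoding step: there is no cache yet, i.e. past_key_values=None,
    # and an onnx graph has no way to accept None for a tensor input
    step1 = model(input_ids=enc.input_ids, decoder_input_ids=decoder_input_ids, use_cache=True)

    # from the second step on the cache exists and only the newest token goes through the decoder
    next_token = step1.logits[:, -1:].argmax(-1)
    step2 = model(
        input_ids=enc.input_ids,  # encoder is re-run here, kept simple for illustration
        decoder_input_ids=next_token,
        past_key_values=step1.past_key_values,
        use_cache=True,
    )
```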
One other solution is to export the encoder and lm_head to onnx and keep the decoder in torch; this way the decoder can use the past_key_values.
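To illustrate, the encoder-side export could look roughly like the sketch below; the wrapper class, file name and opset version are assumptions for illustration only:

```python
import os
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

class EncoderWrapper(torch.nn.Module):
    """Return a plain tensor so the exported graph has no dict/None outputs."""
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def forward(self, input_ids, attention_mask):
        return self.encoder(input_ids=input_ids, attention_mask=attention_mask)[0]

model = T5ForConditionalGeneration.from_pretrained("t5-small").eval()
tokenizer = AutoTokenizer.from_pretrained("t5-small")
enc = tokenizer("translate English to French: This is cool!", return_tensors="pt")

os.makedirs("onnx_models", exist_ok=True)
torch.onnx.export(
    EncoderWrapper(model.encoder),
    (enc.input_ids, enc.attention_mask),
    "onnx_models/t5-small-encoder.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["hidden_states"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "seq"},
        "attention_mask": {0: "batch", 1: "seq"},
        "hidden_states": {0: "batch", 1: "seq"},
    },
    opset_version=12,
)
```

The decoder (and the rest of generation) then keeps running in torch, so past_key_values works exactly as before.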
I’ve written a proof-of-concept script which does exactly this and also makes it compatible with the generate method. With this you can:
```python
from transformers import AutoTokenizer

# OnnxT5 comes from the proof-of-concept script
tokenizer = AutoTokenizer.from_pretrained("t5-small")

enc = tokenizer("translate English to French: This is cool!", return_tensors="pt")
onnx_model = OnnxT5(model_name_or_path="t5-small", onnx_path="onnx_models")
tokens = onnx_model.generate(**enc, num_beams=2, use_cache=True)  # same as HF's generate method
tokenizer.batch_decode(tokens)
```
In my experiments this gave ~1.4-1.6x speed-up with beam search.
The first time you call OnnxT5 it’ll load the model from the hub, export it to onnx as described above, and save the exported graphs at onnx_path. So loading will be slower the first time.
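On later runs the saved graphs can simply be reloaded. For reference, reusing a saved graph could look roughly like this (the file name is assumed; this is not the actual OnnxT5 internals):

```python
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
enc = tokenizer("translate English to French: This is cool!", return_tensors="pt")

# load the already-exported encoder graph instead of exporting again
session = ort.InferenceSession("onnx_models/t5-small-encoder.onnx", providers=["CPUExecutionProvider"])
hidden_states = session.run(
    None,
    {"input_ids": enc.input_ids.numpy(), "attention_mask": enc.attention_mask.numpy()},
)[0]
```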
Now, to gain further speed-up we could distill the model and use fewer decoder layers. onnx + distillation should give even more speed-up with minimal drop in accuracy.
@sshleifer has just published an awesome seq2seq distillation paper, which can be used to distill T5 models as well.
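As a rough sketch of the "shrink" part of that idea (one possible way to do it, not the paper’s exact recipe; the student still needs to be fine-tuned/distilled afterwards), the student can be initialized from a subset of the teacher’s decoder layers:

```python
from transformers import T5Config, T5ForConditionalGeneration

teacher = T5ForConditionalGeneration.from_pretrained("t5-small")

# student with only 3 decoder layers instead of 6
student_config = T5Config.from_pretrained("t5-small", num_decoder_layers=3)
student = T5ForConditionalGeneration(student_config)

# copy everything that matches (embeddings, encoder, lm_head, first decoder blocks)
student.load_state_dict(teacher.state_dict(), strict=False)

# spread the copied decoder blocks across the teacher's depth (blocks 0, 2, 4 of 6)
for student_idx, teacher_idx in enumerate([0, 2, 4]):
    student.decoder.block[student_idx].load_state_dict(
        teacher.decoder.block[teacher_idx].state_dict()
    )
```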
I’ll be sharing T5 distillation results soon!
Now, this is a very hacky solution, so feel free to share feedback and improvements, or suggest any other method that can help speed things up.