Seq2seq decoding is inherently slow, and using onnx is one obvious solution to speed it up. The onnxt5 package already provides one way to use onnx for T5.

But if we export the complete T5 model to onnx, then we can’t use past_key_values for decoding, since for the first decoding step past_key_values will be None and onnx doesn’t accept None as an input. Without past_key_values, onnx won’t give any speed-up over torch for beam search.
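To make the issue concrete, here’s a minimal sketch of how the decoder gets called during generation, with and without the cache. This is illustrative only and assumes the T5Stack call signature and outputs from recent transformers versions, which may differ slightly across releases:

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small").eval()

enc = tokenizer("translate English to French: This is cool!", return_tensors="pt")
encoder_hidden = model.encoder(input_ids=enc["input_ids"], attention_mask=enc["attention_mask"])[0]

decoder_input = torch.tensor([[model.config.decoder_start_token_id]])

# first step: there is no cache yet, so past_key_values is None
step1 = model.decoder(
    input_ids=decoder_input,
    encoder_hidden_states=encoder_hidden,
    past_key_values=None,
    use_cache=True,
)

# later steps: the cache from the previous step is fed back in, which is
# exactly the part a single static onnx graph can't express for step one
step2 = model.decoder(
    input_ids=decoder_input,  # in real generation this is the newly chosen token
    encoder_hidden_states=encoder_hidden,
    past_key_values=step1.past_key_values,
    use_cache=True,
)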
One other solution is to export the encoder and lm_head to onnx and keep the decoder in torch, so that the decoder can still use past_key_values.
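For illustration, here’s a minimal sketch of what exporting just the encoder and lm_head could look like. This is not the proof-of-concept script mentioned below; the file names under onnx_models/ and the opset version are my own assumptions:

import os
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

os.makedirs("onnx_models", exist_ok=True)
model = T5ForConditionalGeneration.from_pretrained("t5-small").eval()
tokenizer = T5Tokenizer.from_pretrained("t5-small")
enc = tokenizer("translate English to French: This is cool!", return_tensors="pt")

class EncoderWrapper(torch.nn.Module):
    """Wraps the T5 encoder so it returns a plain tensor, which is onnx-friendly."""
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
    def forward(self, input_ids, attention_mask):
        return self.encoder(input_ids=input_ids, attention_mask=attention_mask)[0]

torch.onnx.export(
    EncoderWrapper(model.encoder),
    (enc["input_ids"], enc["attention_mask"]),
    "onnx_models/t5_encoder.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["hidden_states"],
    dynamic_axes={
        "input_ids": {0: "batch", 1: "seq"},
        "attention_mask": {0: "batch", 1: "seq"},
        "hidden_states": {0: "batch", 1: "seq"},
    },
    opset_version=12,
)

# the lm_head is a single linear layer, so it exports the same way
torch.onnx.export(
    model.lm_head,
    torch.zeros(1, 1, model.config.d_model),
    "onnx_models/t5_lm_head.onnx",
    input_names=["hidden_states"],
    output_names=["logits"],
    dynamic_axes={
        "hidden_states": {0: "batch", 1: "seq"},
        "logits": {0: "batch", 1: "seq"},
    },
    opset_version=12,
)

The decoder keeps running in torch, so past_key_values flows through each generation step exactly as before.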
I’ve written a proof-of-concept script which does exactly this and also makes it compatible with the generate method.
With this you can:

from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
enc = tokenizer("translate English to French: This is cool!", return_tensors="pt")
onnx_model = OnnxT5(model_name_or_path="t5-small", onnx_path="onnx_models")  # OnnxT5 is defined in the script above
tokens = onnx_model.generate(**enc, num_beams=2, use_cache=True)  # same as HF’s generate method
tokenizer.batch_decode(tokens)
In my experiments this gave ~1.4-1.6x speed-up with beam search.
The first time you call OnnxT5 it’ll load the model from the hub, export it to onnx as described above, and save the exported graphs at onnx_path, so loading will be slower the first time.
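Once the graphs are on disk, later runs can just load them with onnxruntime instead of re-exporting. A sketch of what that looks like for the encoder, assuming the file and input names from the export sketch above (not necessarily what the script uses):

import onnxruntime as ort
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
enc = tokenizer("translate English to French: This is cool!", return_tensors="pt")

session = ort.InferenceSession("onnx_models/t5_encoder.onnx")
hidden_states = session.run(
    None,  # return every graph output
    {
        "input_ids": enc["input_ids"].numpy(),
        "attention_mask": enc["attention_mask"].numpy(),
    },
)[0]
print(hidden_states.shape)  # (batch, seq_len, d_model)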
Now, to gain further speed-up, we could distill the model and use fewer decoder layers. onnx + distillation should give even more speed-up with a minimal drop in accuracy.
@sshleifer has just published an awesome seq2seq distillation paper which can be used to distill T5 models as well.
I’ll be sharing T5 distillation results soon!
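As a rough illustration of the “fewer decoder layers” idea (my own sketch in the spirit of shrink-and-fine-tune, not the paper’s exact recipe), a smaller student could be initialised like this:

import torch
from transformers import T5ForConditionalGeneration

student = T5ForConditionalGeneration.from_pretrained("t5-small")
layers_to_keep = [0, 2, 5]  # illustrative pick out of t5-small's 6 decoder layers

# keep only a subset of the pretrained decoder blocks
student.decoder.block = torch.nn.ModuleList(
    [student.decoder.block[i] for i in layers_to_keep]
)
student.config.num_decoder_layers = len(layers_to_keep)
# then fine-tune `student` on the downstream task (optionally against the
# teacher's outputs) before exporting the encoder/lm_head to onnx as above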
Now, this is a very hacky solution, so feel free to share feedback, suggest improvements, or point out any other method that can help speed things up.
cc @abel, @patrickvonplaten, @sshleifer