Speeding up T5 inference 🚀

Seq2seq decoding is inherently slow, and using ONNX is one obvious way to speed it up. The onnxt5 package already provides one way to use ONNX for T5.

But if we export the complete T5 model to ONNX, we can't use past_key_values for decoding: on the first decoding step past_key_values is None, and ONNX doesn't accept None inputs. Without past_key_values, ONNX won't give any speed-up over torch for beam search.
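To make that cost concrete, here is a toy single-head attention decoder in plain torch (not T5, and not the script's code). The cached loop starts with `past = None`, mirroring the first decoding step, and afterwards projects only the newest token, while the uncached loop recomputes keys and values for the whole prefix at every step. Both give the same outputs; the cached loop simply does less work per step, which is the saving a full-model ONNX export without past_key_values gives up.

```python
import torch

torch.manual_seed(0)
d = 16
# Random projection matrices for a single attention head (toy values, not T5 weights).
W_q, W_k, W_v = (torch.randn(d, d) / d ** 0.5 for _ in range(3))


def attend(q, k, v):
    """Attention of one query row q (1, d) over keys/values k, v of shape (t, d)."""
    weights = torch.softmax(q @ k.T / d ** 0.5, dim=-1)
    return weights @ v


def decode_without_cache(xs):
    outs = []
    for t in range(1, len(xs) + 1):
        prefix = xs[:t]
        # Keys and values for the whole prefix are recomputed at every step.
        outs.append(attend(prefix[-1:] @ W_q, prefix @ W_k, prefix @ W_v))
    return torch.cat(outs)


def decode_with_cache(xs):
    outs, past = [], None  # first step: nothing cached yet, like past_key_values=None
    for t in range(len(xs)):
        x = xs[t:t + 1]
        k_new, v_new = x @ W_k, x @ W_v  # only the newest token is projected
        if past is None:
            past = (k_new, v_new)
        else:
            past = (torch.cat([past[0], k_new]), torch.cat([past[1], v_new]))
        outs.append(attend(x @ W_q, *past))
    return torch.cat(outs)


xs = torch.randn(6, d)  # six "tokens" worth of hidden states
no_cache, cached = decode_without_cache(xs), decode_with_cache(xs)
print(torch.allclose(no_cache, cached, atol=1e-5))
```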

Another solution is to export the encoder and lm_head to ONNX and keep the decoder in torch. This way, the decoder can still use past_key_values.
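A minimal sketch of that split, assuming a recent transformers version. `EncoderWrapper` and `LMHeadWrapper` are hypothetical names rather than the classes from the script, and a tiny randomly initialised T5 (sizes assumed) stands in for `t5-small` so nothing is downloaded. The wrappers are the parts one would pass to `torch.onnx.export` (not shown here); the decoder stays in torch and carries past_key_values from step to step. One detail the lm_head wrapper has to reproduce: when input and output embeddings are tied, T5ForConditionalGeneration scales the decoder output by `d_model**-0.5` before `lm_head`.

```python
import torch
from transformers import T5Config, T5ForConditionalGeneration


class EncoderWrapper(torch.nn.Module):
    """Encoder only: a candidate for ONNX export."""

    def __init__(self, model):
        super().__init__()
        self.encoder = model.get_encoder()

    def forward(self, input_ids, attention_mask):
        return self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state


class LMHeadWrapper(torch.nn.Module):
    """lm_head only, including T5's rescaling when embeddings are tied."""

    def __init__(self, model):
        super().__init__()
        self.lm_head = model.lm_head
        self.scale = model.config.d_model ** -0.5 if model.config.tie_word_embeddings else 1.0

    def forward(self, decoder_hidden_states):
        return self.lm_head(decoder_hidden_states * self.scale)


torch.manual_seed(0)
config = T5Config(vocab_size=64, d_model=32, d_kv=8, d_ff=64, num_layers=2,
                  num_heads=4, decoder_start_token_id=0, pad_token_id=0)
model = T5ForConditionalGeneration(config).eval()
encoder, lm_head = EncoderWrapper(model), LMHeadWrapper(model)

input_ids = torch.tensor([[5, 6, 7, 1]])
attention_mask = torch.ones_like(input_ids)
decoder_ids = torch.tensor([[0, 9, 11]])

with torch.no_grad():
    enc_out = encoder(input_ids, attention_mask)  # run once per input
    past, step_logits = None, []
    for t in range(decoder_ids.shape[1]):
        # The torch decoder accepts past_key_values=None on the first step.
        out = model.decoder(input_ids=decoder_ids[:, t:t + 1], encoder_hidden_states=enc_out,
                            encoder_attention_mask=attention_mask, past_key_values=past,
                            use_cache=True)
        past = out.past_key_values
        step_logits.append(lm_head(out.last_hidden_state))
    cached_logits = torch.cat(step_logits, dim=1)
    full_logits = model(input_ids=input_ids, attention_mask=attention_mask,
                        decoder_input_ids=decoder_ids).logits

print(torch.allclose(cached_logits, full_logits, atol=1e-5))
```

If the split, cached pipeline matches the full model's logits, each wrapper can then be swapped for an ONNX session and checked the same way.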

I’ve written a proof-of-concept script that does exactly this and also makes it compatible with the generate method.

With this, you can do:

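from transformers import AutoTokenizer
# OnnxT5 is defined in the proof-of-concept script, not in transformers
tokenizer = AutoTokenizer.from_pretrained("t5-small")
enc = tokenizer("translate English to French: This is cool!", return_tensors="pt")
onnx_model = OnnxT5(model_name_or_path="t5-small", onnx_path="onnx_models")
tokens = onnx_model.generate(**enc, num_beams=2, use_cache=True)  # same API as HF's generate method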

In my experiments, this gave a ~1.4-1.6x speed-up with beam search.
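Speed-up figures like these depend on hardware, sequence length, and beam size, so it helps to measure them the same way every time. A minimal, framework-agnostic timing sketch; `median_latency` and `speedup` are hypothetical helpers, and the sleeping lambdas only stand in for real workloads.

```python
import statistics
import time


def median_latency(fn, repeats=10, warmup=2):
    """Median wall-clock seconds of fn() over `repeats` timed runs after `warmup` untimed ones."""
    for _ in range(warmup):
        fn()  # warm-up runs absorb one-off costs such as lazy initialisation
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return statistics.median(timings)


def speedup(baseline_fn, candidate_fn, **kwargs):
    """Ratio of median latencies; values above 1 mean candidate_fn is faster."""
    return median_latency(baseline_fn, **kwargs) / median_latency(candidate_fn, **kwargs)


# Stand-in workloads; in practice these would wrap the torch and ONNX generate calls.
ratio = speedup(lambda: time.sleep(0.02), lambda: time.sleep(0.01), repeats=5, warmup=1)
print(f"{ratio:.2f}x")
```

With real models you would pass something like `lambda: torch_model.generate(**enc, num_beams=2)` and `lambda: onnx_model.generate(**enc, num_beams=2)`, where `torch_model` is a hypothetical plain transformers model loaded from the same checkpoint.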

The first time you call OnnxT5, it’ll load the model from the hub, export it to ONNX as described above, and save the exported graphs at onnx_path, so loading will be slower the first time.
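That first-call behaviour can be sketched as a check-then-export step. This is an assumption about how such logic could look, not the script's code: `get_or_export`, the file names `encoder.onnx` and `lm_head.onnx`, and `fake_export` (which only writes placeholder files) are all hypothetical.

```python
import tempfile
from pathlib import Path


def get_or_export(onnx_path, export_fn, names=("encoder.onnx", "lm_head.onnx")):
    """Return the graph paths under onnx_path, calling export_fn(paths) only if any is missing."""
    out_dir = Path(onnx_path)
    paths = [out_dir / name for name in names]
    if not all(p.exists() for p in paths):
        out_dir.mkdir(parents=True, exist_ok=True)
        export_fn(paths)  # slow path (download + ONNX export), taken on the first call only
    return paths


export_calls = []


def fake_export(paths):
    # Stand-in for the real export: records the call and writes placeholder files.
    export_calls.append(paths)
    for p in paths:
        p.write_bytes(b"placeholder")


with tempfile.TemporaryDirectory() as tmp:
    first = get_or_export(Path(tmp) / "onnx_models", fake_export)
    second = get_or_export(Path(tmp) / "onnx_models", fake_export)  # reuses the saved files
    files_existed = all(p.exists() for p in second)

print(len(export_calls))  # 1
```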

Now, to gain further speed-up, we could distill the model and use fewer decoder layers.
ONNX + distillation should give even more speed-up with a minimal drop in accuracy.
@sshleifer has just published an awesome seq2seq distillation paper, which can be used to distill T5 models as well.
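One approach studied in that line of work is "shrink and fine-tune": initialise a smaller student by copying a subset of the teacher's layers, then fine-tune it. A minimal sketch of the copying step for T5's decoder, assuming a recent transformers version; `shrink_decoder` is a hypothetical helper, the tiny random teacher stands in for a real checkpoint, and the fine-tuning itself is not shown. T5 keeps the decoder's relative attention bias in block 0, so the sketch requires that block to be kept first.

```python
import copy

import torch
from transformers import T5Config, T5ForConditionalGeneration


def shrink_decoder(teacher, keep_layers):
    """Hypothetical helper: student equal to teacher except for a decoder with only keep_layers."""
    assert keep_layers[0] == 0, "decoder block 0 holds T5's relative attention bias"
    config = copy.deepcopy(teacher.config)
    config.num_decoder_layers = len(keep_layers)
    student = T5ForConditionalGeneration(config)
    teacher_state = teacher.state_dict()
    student_state = {}
    for key in student.state_dict():
        source_key = key
        if key.startswith("decoder.block."):
            # e.g. "decoder.block.1.layer.0.SelfAttention.q.weight" -> teacher block keep_layers[1]
            _, _, idx, rest = key.split(".", 3)
            source_key = f"decoder.block.{keep_layers[int(idx)]}.{rest}"
        student_state[key] = teacher_state[source_key]
    student.load_state_dict(student_state)
    return student.eval()


torch.manual_seed(0)
teacher_config = T5Config(vocab_size=64, d_model=32, d_kv=8, d_ff=64, num_layers=4,
                          num_heads=4, decoder_start_token_id=0, pad_token_id=0)
teacher = T5ForConditionalGeneration(teacher_config).eval()
student = shrink_decoder(teacher, keep_layers=[0, 2])

with torch.no_grad():
    logits = student(input_ids=torch.tensor([[5, 6, 7, 1]]),
                     decoder_input_ids=torch.tensor([[0, 9]])).logits
print(len(student.decoder.block), logits.shape)
```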

I’ll be sharing T5 distillation results soon!

Now, this is a very hacky solution, so feel free to suggest feedback, improvements, or any other method that can help speed things up :slight_smile:

cc @abel, @patrickvonplaten, @sshleifer


Here’s a self-contained notebook and a small benchmark for summarization and translation tasks.

A question from Twitter:

Can’t we import T5Model from transformers and then maybe pass it through a wrapper or decorator for ONNX optimization? Or do we really need to change some ops in torch so that they are ONNX compatible?

We could export T5 as it is, but then we can’t use past_key_values, because they will be None for the first decoding step and ONNX doesn’t accept None inputs. So here we export the individual components to be able to use past_key_values.

Here you could also export the encoder and lm_head directly instead of creating wrapper classes for them; I did it this way to make it more readable. We also need the extra code in the OnnxT5 class to make it compatible with generate.

Feel free to suggest changes if you think it can be improved :slight_smile:

Thanks @valhalla. The reason for my doubt was that we have to change some parts of the code to make it work as it is, and the same issue exists in TensorFlow too. Very nice post and super useful. Thanks for sharing :-) :+1:

What makes the ONNX T5 model faster than the regular PyTorch T5 model?

It’s because of onnxruntime, which applies several graph optimizations to speed up inference. You can find more info in the ONNX Runtime docs.


You can also quantize the model. As mentioned in the PyTorch docs:

PyTorch supports INT8 quantization compared to typical FP32 models allowing for a 4x reduction in the model size and a 4x reduction in memory bandwidth requirements. Hardware support for INT8 computations is typically 2 to 4 times faster compared to FP32 compute.

This notebook shows benchmarks of using quantized models with ONNX.
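For reference, PyTorch’s dynamic quantization (`torch.ao.quantization.quantize_dynamic`) stores `Linear` weights as INT8 and quantizes activations on the fly at inference time. A minimal sketch on a toy stack of Linear layers rather than T5; the linked notebook may rely on ONNX Runtime’s own quantization tooling instead, which this does not show.

```python
import io

import torch

torch.manual_seed(0)
# Toy stand-in for a model dominated by Linear layers.
model = torch.nn.Sequential(
    torch.nn.Linear(512, 512), torch.nn.ReLU(), torch.nn.Linear(512, 512)
).eval()

# Returns a copy in which every torch.nn.Linear uses INT8 weights.
quantized = torch.ao.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)


def state_dict_bytes(m):
    """Serialized size of a module's state dict, in bytes."""
    buffer = io.BytesIO()
    torch.save(m.state_dict(), buffer)
    return buffer.getbuffer().nbytes


x = torch.randn(8, 512)
with torch.no_grad():
    max_err = (model(x) - quantized(x)).abs().max().item()
ratio = state_dict_bytes(model) / state_dict_bytes(quantized)
print(f"size ratio {ratio:.1f}x, max abs error {max_err:.4f}")
```

On a full model, only the module types listed (here `torch.nn.Linear`) are converted, so the overall size reduction is smaller than 4x when other layers, such as embeddings, stay in FP32.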

Hey @valhalla, does this ONNX Runtime solution also work with mT5?

It might not work as is, but it should be easy to adapt. I think replacing the T5-related classes with their MT5 counterparts should make it work, for example, T5ForConditionalGeneration with MT5ForConditionalGeneration, etc.
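Two details are worth checking when swapping the classes. transformers’ Auto classes can pick the right class from the config, and (at least in transformers 4.x) MT5Config defaults to `tie_word_embeddings=False`, so the `d_model**-0.5` rescaling that T5ForConditionalGeneration applies before `lm_head` only when embeddings are tied must not be copied over unconditionally. A minimal sketch on a tiny randomly initialised mT5 (sizes assumed, nothing downloaded):

```python
import torch
from transformers import AutoModelForSeq2SeqLM, MT5Config

torch.manual_seed(0)
config = MT5Config(vocab_size=64, d_model=32, d_kv=8, d_ff=64, num_layers=2,
                   num_heads=4, decoder_start_token_id=0, pad_token_id=0)
# The Auto class resolves MT5Config to MT5ForConditionalGeneration.
model = AutoModelForSeq2SeqLM.from_config(config).eval()

# Rescale decoder outputs only when the embeddings are tied, as the full model does.
scale = model.config.d_model ** -0.5 if model.config.tie_word_embeddings else 1.0

input_ids = torch.tensor([[5, 6, 7, 1]])
decoder_ids = torch.tensor([[0, 9]])
with torch.no_grad():
    enc = model.get_encoder()(input_ids=input_ids).last_hidden_state
    dec = model.get_decoder()(input_ids=decoder_ids, encoder_hidden_states=enc).last_hidden_state
    split_logits = model.lm_head(dec * scale)
    full_logits = model(input_ids=input_ids, decoder_input_ids=decoder_ids).logits

print(type(model).__name__, model.config.tie_word_embeddings,
      torch.allclose(split_logits, full_logits, atol=1e-5))
```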

Hey @valhalla

Hope you’re well. Firstly, I want to thank you for sharing this.

I was trying to run the notebook you shared to reproduce the results you reported, but I kept running into an error. Would you be able to help me fix it? The only change I made to the notebook was installing the sentencepiece library, as it is required by T5Tokenizer.

Also, did you get a chance to distill the T5 model and produce some results? Looking forward to those :slight_smile:


Hi @rigoh5,

What transformers version are you using? Could you try with 3.1.0? I haven’t tested it with the latest version.

Hey @valhalla,

I’ve created the fastT5 library. It exports T5 models to ONNX with past_key_values, and the exported model supports the generate() method.
With the ONNX models, I was getting inference speed-ups of up to 5x for greedy search and 3-4x for beam search. For more details, check out the fastT5 repo.

Currently, the transformers library does not support exporting T5 to ONNX with past_key_values; you can fix this by following the guide in this notebook. I’ve created a PR for this support here.


This is awesome, great job! You should open a new thread to share this so that more people know about it.

I’ll take a look at the PR.

Thank you, @valhalla! I’ve created a new thread here.

How can we do this with Pegasus?
I was trying to adapt the script for Pegasus by loading a Pegasus model instead, and I was able to get the encoder and decoder with:

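from transformers import PegasusForConditionalGeneration
model = PegasusForConditionalGeneration.from_pretrained("model-path-from-huggingface")  # placeholder checkpoint
encoder = model.model.encoder
decoder = model.model.decoder
lm_head = model.lm_head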

But I get an error even when I try to build the original T5. It does generate the T5 ONNX model, but I get errors when trying to use it.

Hi @kira,

Thanks for fastT5.

Do you plan to update it? The latest version (0.1.2, 01/06/2022) is incompatible with transformers>=4.16.

Note: if you want to use the latest fastT5 version (0.1.2), install fastT5 from source by running the following code in a notebook cell:

!git clone https://github.com/Ki6an/fastT5
%cd fastT5

Download the setup.py file and change the line "transformers>4.6.1" to "transformers==4.15.0", since transformers versions >= 4.16.0 do not work with the latest fastT5 version (0.1.2).

Save and upload the setup.py file before launching the following cell:
!pip3 install -e .

Et voilà.
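Rather than downloading, editing, and re-uploading setup.py, the same pin can be applied in place from a notebook cell. A minimal sketch; `pin_requirement` is a hypothetical helper, and it assumes the requirement string appears in setup.py exactly as quoted above. The demo runs on a throwaway file so it can be tried anywhere.

```python
import tempfile
from pathlib import Path


def pin_requirement(setup_py, old="transformers>4.6.1", new="transformers==4.15.0"):
    """Replace the `old` requirement string with `new` in setup_py, in place."""
    path = Path(setup_py)
    text = path.read_text()
    if old not in text:
        raise ValueError(f"{old!r} not found in {path}")
    path.write_text(text.replace(old, new))


# Demo on a throwaway file; in the notebook you would call pin_requirement("setup.py")
# from inside the cloned fastT5 directory, before running pip install.
with tempfile.TemporaryDirectory() as tmp:
    demo = Path(tmp) / "setup.py"
    demo.write_text('install_requires=["transformers>4.6.1"]\n')
    pin_requirement(demo)
    patched = demo.read_text()

print(patched)
```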


If you are using PegasusForConditionalGeneration:

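from transformers import PegasusForConditionalGeneration

model = PegasusForConditionalGeneration.from_pretrained("model-path-from-huggingface")
encoder = model.get_encoder()  # get_encoder/get_decoder are methods, so call them
decoder = model.get_decoder()
lm_head = model.lm_head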

If you are using the PegasusModel class from transformers:

from transformers import PegasusModel

model = PegasusModel.from_pretrained("model-path-from-huggingface")
encoder = model.encoder
decoder = model.decoder

But you can’t get model.lm_head, because lm_head is not part of PegasusModel.
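Getting the modules is only half of it; a split pipeline also has to reproduce how PegasusForConditionalGeneration turns decoder outputs into logits. In transformers it adds a `final_logits_bias` buffer after `lm_head` and, unlike T5, applies no `d_model**-0.5` rescaling. A minimal sketch on a tiny randomly initialised Pegasus (sizes assumed, nothing downloaded) that checks the pieces against the full model:

```python
import torch
from transformers import PegasusConfig, PegasusForConditionalGeneration

torch.manual_seed(0)
# Tiny random Pegasus so the sketch runs offline; use from_pretrained(...) for a real checkpoint.
config = PegasusConfig(vocab_size=64, d_model=32, encoder_layers=2, decoder_layers=2,
                       encoder_attention_heads=4, decoder_attention_heads=4,
                       encoder_ffn_dim=64, decoder_ffn_dim=64, max_position_embeddings=64,
                       pad_token_id=0, eos_token_id=1, decoder_start_token_id=0)
model = PegasusForConditionalGeneration(config).eval()

encoder = model.get_encoder()
decoder = model.get_decoder()
lm_head = model.lm_head

input_ids = torch.tensor([[5, 6, 7, 1]])
attention_mask = torch.ones_like(input_ids)
decoder_ids = torch.tensor([[0, 9, 11]])

with torch.no_grad():
    enc_out = encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
    dec_out = decoder(input_ids=decoder_ids, encoder_hidden_states=enc_out,
                      encoder_attention_mask=attention_mask).last_hidden_state
    # Pegasus adds final_logits_bias after lm_head and does no d_model**-0.5 rescale.
    split_logits = lm_head(dec_out) + model.final_logits_bias
    full_logits = model(input_ids=input_ids, attention_mask=attention_mask,
                        decoder_input_ids=decoder_ids).logits

print(torch.allclose(split_logits, full_logits, atol=1e-5))
```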