Boost inference speed of T5 models up to 5X & reduce the model size by 3X

T5 model inference is naturally slow because generation happens through sequential seq2seq decoding. To speed it up, we can convert the T5 model to ONNX and run it on onnxruntime.

These are the steps to run T5 models on onnxruntime:

  • Export T5 to ONNX with past_key_values.

past_key_values contains the pre-computed hidden states (the keys and values of the self-attention and cross-attention blocks) that can be reused to speed up sequential decoding; a plain-PyTorch sketch of this is shown after the list.

  • Quantize the model (optional). Quantization reduces the model size and further increases speed; see the snippet after the list.
  • Run these models on onnxruntime.
  • The exported ONNX or quantized ONNX model should support both greedy search and beam search.
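To make the role of past_key_values concrete, here is a minimal plain-PyTorch greedy-decoding sketch with an ordinary transformers T5 model (not fastT5 or ONNX code; model.generate() does this for you internally). The cache returned at each step is fed back in, so the decoder only has to process the newly generated token:

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained('t5-small')
model = T5ForConditionalGeneration.from_pretrained('t5-small').eval()

enc = tokenizer("translate English to German: Hello world", return_tensors='pt')
encoder_outputs = model.encoder(input_ids=enc['input_ids'],
                                attention_mask=enc['attention_mask'])

decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])
past_key_values = None
for _ in range(20):
    out = model(encoder_outputs=encoder_outputs,
                attention_mask=enc['attention_mask'],
                decoder_input_ids=decoder_input_ids[:, -1:],  # only the newest token
                past_key_values=past_key_values,
                use_cache=True)
    past_key_values = out.past_key_values  # reuse the cached keys/values next step
    next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=-1)
    if next_token.item() == model.config.eos_token_id:
        break

print(tokenizer.decode(decoder_input_ids[0], skip_special_tokens=True))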

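For the optional quantization step, onnxruntime ships a dynamic quantizer that converts an exported graph's weights to int8 in a single call; a sketch with placeholder file names (fastT5 runs an equivalent step for you):

from onnxruntime.quantization import quantize_dynamic, QuantType

# placeholder paths for an exported T5 component
quantize_dynamic("t5-small-encoder.onnx",
                 "t5-small-encoder-quant.onnx",
                 weight_type=QuantType.QInt8)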
As you can see, the whole process looks complicated, so I've created the fastT5 library to make it simple. All of the above steps can be done with a single line of code using fastT5.

pip install fastt5

from fastT5 import export_and_get_onnx_model
model = export_and_get_onnx_model('t5-small')

The model also supports the generate() method. The input IDs and attention mask come from a regular transformers tokenizer:

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained('t5-small')
token = tokenizer("translate English to French: Hello", return_tensors='pt')
tokens = model.generate(input_ids=token['input_ids'],
                        attention_mask=token['attention_mask'],
                        num_beams=2)
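generate() returns a tensor of token IDs, which can be decoded back to text with the same tokenizer:

print(tokenizer.decode(tokens.squeeze(), skip_special_tokens=True))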

For more info, check out the repo.

NOTE
Currently, the transformers library does not support exporting T5 to ONNX with past_key_values. You can work around this by following the guide in this notebook. I have created a PR for this support here.
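For context, exporting the stand-alone encoder is the easy part and works with vanilla torch.onnx.export; it is the decoder with past_key_values inputs and outputs that needs the workaround above. A simplified, encoder-only sketch (file name, dummy shapes, and opset are arbitrary choices):

import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained('t5-small').eval()

class EncoderWrapper(torch.nn.Module):
    """Return a plain tensor so torch.onnx.export can trace the encoder."""
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder
    def forward(self, input_ids, attention_mask):
        return self.encoder(input_ids=input_ids,
                            attention_mask=attention_mask).last_hidden_state

dummy_ids = torch.ones(1, 8, dtype=torch.long)
dummy_mask = torch.ones(1, 8, dtype=torch.long)
torch.onnx.export(EncoderWrapper(model.encoder),
                  (dummy_ids, dummy_mask),
                  "t5-small-encoder.onnx",
                  input_names=["input_ids", "attention_mask"],
                  output_names=["hidden_states"],
                  dynamic_axes={"input_ids": {0: "batch", 1: "seq"},
                                "attention_mask": {0: "batch", 1: "seq"},
                                "hidden_states": {0: "batch", 1: "seq"}},
                  opset_version=12)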


So this solution only works with onnxruntime, and does not work with TorchServe?

Hello @kira,
I am working on speeding up batch CPU inference for a fine-tuned t5-mini.

With batch size = 10 and sequence length = 300 tokens:

  • t5-mini inference speed: 3 sec
  • t5-mini after PyTorch built-in dynamic quantization: 2.3 sec (sketched after this list)
  • fastT5 after converting to ONNX and quantization: 5.9 sec!!
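For reference, the PyTorch built-in dynamic quantization measured above is roughly the following (t5-small shown as a stand-in, since the fine-tuned t5-mini checkpoint isn't shared):

import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained('t5-small').eval()
# quantize only the Linear layers' weights to int8; activations stay float32
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8)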

Maybe I am doing something wrong, but fastT5 was supposed to be faster, right?

Colab notebook link:

Please let me know your thoughts.
