BART-MNLI performance optimization

Hi, everyone!

I need to classify texts that are about 100 words long on average into 1.5k classes in a zero-shot setting.
To solve this task I am using the facebook/bart-large-mnli model.
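
For context, here is roughly what my setup looks like in code (the label list is shortened here; the real one has ~1.5k entries):

```python
from transformers import pipeline

# Zero-shot classification with BART-MNLI, running on CPU
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    device=-1,  # CPU
)

# In reality this list has ~1.5k candidate labels
candidate_labels = ["sports", "politics", "finance"]

result = classifier("A ~100-word text to categorize ...", candidate_labels)
print(result["labels"][0], result["scores"][0])  # best label and its score
```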

My setup is 32 CPU cores and 250 GB of RAM.

In the first two pictures below you can see memory consumption during model inference. In both of them I categorize only 4 texts. As you can see, time and memory consumption grow with text length. It seems to me the corresponding complexities are O(n) for time and O(log n) for memory, where n is the number of words in the text.

I wonder if such high resource consumption is expected? On average, classifying just one 90-word text requires more than 50 GB of RAM and takes more than 3 minutes. On 32 CPUs!

I wish I could use my NVIDIA P106 with 6 GB of memory to speed up inference, but this process doesn't fit into its memory. Also, you can see that I tried valhalla/distilbart-mnli-12-1; evidently, I can't use it with my GPU either. Are there any workarounds? And what is the underlying mechanism behind such consumption? Is it attention?
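
One partial workaround I can think of (a sketch, based on my understanding that the pipeline scores one (text, label) NLI pair per candidate label, so peak memory grows with how many pairs are scored at once) is to feed the candidate labels in chunks. Note the multi_label=True flag: with it each label is scored independently, so chunking shouldn't change the scores, whereas the default softmax over all labels would:

```python
def classify_in_chunks(classifier, text, labels, chunk_size=64):
    """Score candidate labels in chunks of chunk_size so that only
    chunk_size (text, label) pairs are held in memory at once."""
    scores = {}
    for i in range(0, len(labels), chunk_size):
        chunk = labels[i:i + chunk_size]
        # multi_label=True -> per-label entailment scores,
        # independent of which other labels are in the chunk
        out = classifier(text, chunk, multi_label=True)
        scores.update(zip(out["labels"], out["scores"]))
    # highest-scoring labels first
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
```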

Also, I tried to convert these MNLI models to ONNX to speed up inference, and at that I miserably failed. I found the few tutorials I could understand and implemented the steps they suggested. The problem is that none of the resulting ONNX models have all the necessary inputs. They have input_ids and attention_mask, which, to my knowledge, are responsible for the text I want to categorize, but they are missing decoder_input_ids and decoder_attention_mask, which, to my knowledge, are responsible for the candidate labels.

I tried transformers.onnx and onnx_transformers by @valhalla, both unsuccessfully. The latter, after I resolved many package conflicts, did run with roberta-mnli, but it ran on only one CPU out of 32 and I didn't find a way to enable the others. Is there any chance somebody here could guide me on how to properly convert bart-mnli to ONNX for inference? Would it help if I paid for a subscription to get such help?
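
For reference, this is the kind of export and inference I was attempting (a sketch: I'm assuming bart-large-mnli is covered by the sequence-classification feature of transformers.onnx, and the thread settings below are just my guess at getting ONNX Runtime onto more cores; also, as far as I can tell, a sequence-classification export derives the decoder inputs from input_ids inside the graph, so input_ids and attention_mask may be all it needs):

```python
# Export first (run once from the shell):
#   python -m transformers.onnx --model=facebook/bart-large-mnli \
#       --feature=sequence-classification onnx/

import onnxruntime as ort
from transformers import AutoTokenizer

# Ask ONNX Runtime to parallelize a single inference across cores
opts = ort.SessionOptions()
opts.intra_op_num_threads = 32

session = ort.InferenceSession(
    "onnx/model.onnx",
    sess_options=opts,
    providers=["CPUExecutionProvider"],
)

tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-mnli")

# NLI formulation: premise = the text, hypothesis = templated label
inputs = tokenizer(
    "A ~100-word text to categorize ...",
    "This example is sports.",
    return_tensors="np",
)

# Output logits are [contradiction, neutral, entailment]
logits = session.run(None, dict(inputs))[0]
print(logits)
```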

Thank you.



No ideas at all? :smiling_face_with_tear:

Did you find a solution to your problem? I would be interested to see how you solved it.

I'm also looking for ways to speed up summarization with the BART model. Did you find a solution?