Custom model export to onnx-runtime

I modified BertEmbeddings, BertModel and BertForTokenClassification to accept an additional feature (whether a token is capitalized or not). In pure transformers it all works, but I am struggling to implement the export of this custom model (so that I can optimize it with optimum and get an inference speed-up).

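from pathlib import Path
from typing import Dict

from optimum.exporters.onnx.config import TextEncoderOnnxConfig
from optimum.exporters.tasks import TasksManager
from optimum.utils import NormalizedTextConfig

register_for_onnx = TasksManager.create_register("onnx")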

@register_for_onnx("custom-bert", "token-classification")
class CustomOnnxConfig(TextEncoderOnnxConfig):
    # Specifies how to normalize the BertConfig, this is needed to access common attributes
    # during dummy input generation.
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    # Sets the absolute tolerance to when validating the exported ONNX model against the
    # reference model.

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        if self.task == "multiple-choice":
            dynamic_axis = {0: "batch_size", 1: "num_choices", 2: "sequence_length"}
        else:
            dynamic_axis = {0: "batch_size", 1: "sequence_length"}
        return {
            "input_ids": dynamic_axis,
            "attention_mask": dynamic_axis,
            "capitalization_ids": dynamic_axis,
            "token_type_ids": dynamic_axis,
        }

base_model = CustomBertForTokenClassification.from_pretrained("my-checkpoint")

onnx_path = Path("model.onnx")

This is where I do not understand what to do next.
base_model.config returns a BertConfig, which I think I need to overwrite with the custom config I created in the previous step.

Can you please help me?

First, I can see that your new model will have a new input. We need a DummyInputGenerator that can handle this. So you could try something like:

from optimum.utils import DummyTextInputGenerator

class MyDummyTextInputenerator(DummyTextInputGenerator):
    SUPPORTED_INPUT_NAMES = (

class CustomOnnxConfig(TextEncoderOnnxConfig):
    # Specifies how to normalize the BertConfig, this is needed to access common attributes
    # during dummy input generation.
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    DUMMY_INPUT_GENERATOR_CLASSES = (MyDummyTextInputenerator,)
    # Sets the absolute tolerance to when validating the exported ONNX model against the
    # reference model.

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        if self.task == "multiple-choice":
            dynamic_axis = {0: "batch_size", 1: "num_choices", 2: "sequence_length"}
        else:
            dynamic_axis = {0: "batch_size", 1: "sequence_length"}
        return {
            "input_ids": dynamic_axis,
            "attention_mask": dynamic_axis,
            "capitalization_ids": dynamic_axis,
            "token_type_ids": dynamic_axis,
        }
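
To make the role of the dummy input generator more concrete, here is a minimal, dependency-free sketch of what such a generator has to provide for the extra input. It deliberately does not use optimum: the function names, the `MAX_VALUE` table, the vocabulary size and the idea that `capitalization_ids` is a 0/1 flag per token are all assumptions made for illustration.

```python
import random
from typing import Dict, List

# Hypothetical stand-in for a dummy text input generator; NOT optimum's class.
SUPPORTED_INPUT_NAMES = (
    "input_ids",
    "attention_mask",
    "token_type_ids",
    "capitalization_ids",  # the extra input of the custom model
)

# Assumed exclusive upper bound of the random values for each input.
MAX_VALUE = {
    "input_ids": 30522,       # assumed vocabulary size
    "attention_mask": 2,
    "token_type_ids": 2,
    "capitalization_ids": 2,  # assumed to be a 0/1 flag per token
}


def generate_dummy_input(
    name: str, batch_size: int = 2, sequence_length: int = 16, seed: int = 0
) -> List[List[int]]:
    """Return a batch_size x sequence_length nested list of random integer ids."""
    if name not in SUPPORTED_INPUT_NAMES:
        raise ValueError(f"Unsupported input name: {name}")
    rng = random.Random(seed)
    return [
        [rng.randrange(MAX_VALUE[name]) for _ in range(sequence_length)]
        for _ in range(batch_size)
    ]


def generate_all_dummy_inputs(
    batch_size: int = 2, sequence_length: int = 16
) -> Dict[str, List[List[int]]]:
    """Generate one dummy nested list per supported input name."""
    return {
        name: generate_dummy_input(name, batch_size, sequence_length)
        for name in SUPPORTED_INPUT_NAMES
    }
```

The point of the sketch is only that every name returned by `inputs` needs a generator that knows how to produce it. An input name the generator does not list cannot be turned into a dummy tensor for tracing.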

Since an OnnxConfig is already registered for bert, the register method will not do anything.
If you want to be able to overwrite existing registered configurations, you can create the register like this:

register_for_onnx = TasksManager.create_register("onnx", overwrite_existing=True)
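
The effect of that flag can be illustrated with a toy registry. This is only a sketch of the general pattern, not optimum's implementation: `REGISTRY`, `create_register` as written here, and the placeholder classes are hypothetical names.

```python
from typing import Callable, Dict, Tuple, Type


class BuiltinBertOnnxConfig:
    """Placeholder for a configuration that is already registered."""


# Toy registry keyed by (model_type, task); NOT optimum's internal structure.
REGISTRY: Dict[Tuple[str, str], Type] = {
    ("bert", "token-classification"): BuiltinBertOnnxConfig,
}


def create_register(overwrite_existing: bool = False) -> Callable:
    """Return a decorator factory that registers a class for (model_type, task)."""

    def register(model_type: str, *tasks: str) -> Callable:
        def decorator(cls: Type) -> Type:
            for task in tasks:
                key = (model_type, task)
                # An existing entry is kept unless overwriting was requested.
                if key in REGISTRY and not overwrite_existing:
                    continue
                REGISTRY[key] = cls
            return cls

        return decorator

    return register


register = create_register()


@register("bert", "token-classification")
class IgnoredConfig:
    """Registered without overwrite: the built-in entry stays in place."""


kept = REGISTRY[("bert", "token-classification")]  # still BuiltinBertOnnxConfig

register_overwrite = create_register(overwrite_existing=True)


@register_overwrite("bert", "token-classification")
class CustomConfig:
    """Registered with overwrite: replaces the built-in entry."""


replaced = REGISTRY[("bert", "token-classification")]  # now CustomConfig
```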

Another approach is to do it programmatically: you can first export your model using Python code and then optimize it using the CLI.

from pathlib import Path
from optimum.exporters.onnx import export


Thanks a lot! I am sorry, but I still have issues:

from optimum.utils import DummyTextInputGenerator
from optimum.exporters.onnx.config import OnnxConfig

class MyDummyTextInputenerator(DummyTextInputGenerator):
    SUPPORTED_INPUT_NAMES = (
class TextEncoderOnnxConfig(OnnxConfig):
    # Describes how to generate the dummy inputs.
    DUMMY_INPUT_GENERATOR_CLASSES = (DummyTextInputGenerator,)

register_for_onnx = TasksManager.create_register("onnx", overwrite_existing=True)

@register_for_onnx("bert", "token-classification")
class CustomOnnxConfig(TextEncoderOnnxConfig):
    # Specifies how to normalize the BertConfig, this is needed to access common attributes
    # during dummy input generation.
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    DUMMY_INPUT_GENERATOR_CLASSES = (MyDummyTextInputenerator,)
    # Sets the absolute tolerance to when validating the exported ONNX model against the
    # reference model.

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        if self.task == "multiple-choice":
            dynamic_axis = {0: "batch_size", 1: "num_choices", 2: "sequence_length"}
        else:
            dynamic_axis = {0: "batch_size", 1: "sequence_length"}
        return {
            "input_ids": dynamic_axis,
            "attention_mask": dynamic_axis,
            "capitalization_ids": dynamic_axis,
            "token_type_ids": dynamic_axis,
        }

model = CustomBertForTokenClassification.from_pretrained("my-checkpoint")

from pathlib import Path
from optimum.exporters.onnx import export


I get the error RuntimeError: number of output names provided (1) exceeded number of outputs (0)

I think this might be linked to this issue.
Could you try installing optimum from sources and re-run your script?

I installed it like this:

pip install transformers[onnx]==4.28.1
python -m pip install git+

But the error stayed the same: RuntimeError: number of output names provided (1) exceeded number of outputs (0)

Hi Maiia,
I saw your message here, thanks for creating a notebook, it really helped me to try things out.

You do not need to register anything since you are not using the TasksManager in the end. Creating your OnnxConfig and using the export function is enough.

But because you are not using the TasksManager, you have to instantiate the OnnxConfig manually, and you need to specify the name of the task when doing so. Otherwise it will infer the task to be the default one (the one for BertModel, not BertForTokenClassification).

So to be able to export your model the fix is actually easy:

    CustomOnnxConfig(model.config, task="token-classification"),
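
Why the task argument matters can be sketched with a toy configuration. This is a simplified, self-contained illustration and does not reproduce the exact RuntimeError above: `ToyOnnxConfig`, `TASK_TO_OUTPUTS`, `matching_outputs` and the task and output names are assumptions, not optimum's internals.

```python
from typing import Dict, List

# Hypothetical, simplified mapping from export task to expected output names.
TASK_TO_OUTPUTS: Dict[str, List[str]] = {
    "feature-extraction": ["last_hidden_state"],
    "token-classification": ["logits"],
}


class ToyOnnxConfig:
    """Toy config whose expected outputs depend on the task it was built with."""

    def __init__(self, task: str = "feature-extraction") -> None:
        # Without an explicit task, the default (base model) task is assumed.
        self.task = task

    @property
    def outputs(self) -> List[str]:
        return TASK_TO_OUTPUTS[self.task]


def matching_outputs(config: ToyOnnxConfig, model_output_names: List[str]) -> List[str]:
    """Return the expected output names that the model actually produces."""
    return [name for name in config.outputs if name in model_output_names]


# A token-classification head is assumed to return only "logits".
token_classifier_outputs = ["logits"]

default_match = matching_outputs(ToyOnnxConfig(), token_classifier_outputs)
explicit_match = matching_outputs(
    ToyOnnxConfig(task="token-classification"), token_classifier_outputs
)
```

In this toy setup the default task expects an output the token classification model never produces, so nothing matches. Passing the task explicitly lines the expected output names up with the model.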

Thank you a lot! You helped me so, so much!

One last question: if I need to reload it (as ORTModelForTokenClassification, to optimize it with ORTOptimizer), do I need to save a modified config? If I save the original config and do

reloaded_model = ORTModelForTokenClassification.from_pretrained("onnx")
inputs = {k: torch.zeros([2, 16], dtype=torch.long) for k in ["input_ids", "attention_mask", "capitalization_ids", "token_type_ids"]}

I get the error ValueError: Model requires 4 inputs. Input Feed contains 3

Or maybe I just need to modify the ORTModelForTokenClassification class, and that will fix the issue…

I modified ORTModelForTokenClassification, and everything works perfectly now!
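
For readers hitting the same ValueError, the likely cause can be sketched without any ONNX dependency: a wrapper whose forward only knows the three standard BERT inputs silently drops the fourth one, and the session then rejects the incomplete feed. Everything below is a hypothetical stand-in (`ToySession`, `stock_forward`, `custom_forward`), not the actual ORTModelForTokenClassification code.

```python
from typing import Dict, List

Tensor = List[List[int]]


class ToySession:
    """Hypothetical stand-in for an inference session that validates its feed."""

    def __init__(self, input_names: List[str]) -> None:
        self.input_names = input_names

    def run(self, feed: Dict[str, Tensor]) -> List[str]:
        if len(feed) != len(self.input_names):
            raise ValueError(
                f"Model requires {len(self.input_names)} inputs. "
                f"Input Feed contains {len(feed)}"
            )
        # A real session would return output arrays; here we return the names fed.
        return sorted(feed)


def stock_forward(session, input_ids, attention_mask, token_type_ids, **unused):
    """Mimics a wrapper that only knows the three standard BERT inputs."""
    feed = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
    }
    return session.run(feed)  # capitalization_ids ended up in **unused


def custom_forward(session, input_ids, attention_mask, token_type_ids, capitalization_ids):
    """Mimics a modified wrapper that also feeds the extra input."""
    feed = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "token_type_ids": token_type_ids,
        "capitalization_ids": capitalization_ids,
    }
    return session.run(feed)
```

With a session that declares four inputs, `stock_forward` raises the same message as above, while `custom_forward` passes the feed check.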

Thank you so, so, so much for your patience and kindness, and for spending time on a beginner!

And for the wonderful work that Hugging Face gives people! Thank you!!!
