I modified BertEmbeddings, BertModel, and BertForTokenClassification to accept an additional feature (whether a token is capitalized or not). In pure transformers it all works, but I am struggling with implementing the export of this custom model (so I can optimize it with optimum and get an inference speed-up).
```python
from pathlib import Path
from typing import Dict

from optimum.exporters import TasksManager
from optimum.exporters.onnx.config import TextEncoderOnnxConfig
from optimum.utils import NormalizedTextConfig

register_for_onnx = TasksManager.create_register("onnx")

@register_for_onnx("custom-bert", "token-classification")
class CustomOnnxConfig(TextEncoderOnnxConfig):
    # Specifies how to normalize the BertConfig, this is needed to access common attributes
    # during dummy input generation.
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

    # Sets the absolute tolerance used when validating the exported ONNX model against the
    # reference model.
    ATOL_FOR_VALIDATION = 1e-4

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        if self.task == "multiple-choice":
            dynamic_axis = {0: "batch_size", 1: "num_choices", 2: "sequence_length"}
        else:
            dynamic_axis = {0: "batch_size", 1: "sequence_length"}
        return {
            "input_ids": dynamic_axis,
            "attention_mask": dynamic_axis,
            "capitalization_ids": dynamic_axis,
            "token_type_ids": dynamic_axis,
        }
```
```python
base_model = CustomBertForTokenClassification.from_pretrained("my-checkpoint")
onnx_path = Path("model.onnx")
```
Here I do not understand what to do next: base_model.config returns a BertConfig, which I think I need to overwrite with the custom config I created in the previous step.
First, I can see that your new model will have a new input. We need a DummyInputGenerator that can handle this. So you could try something like:
```python
from optimum.utils import DummyTextInputGenerator

class MyDummyTextInputGenerator(DummyTextInputGenerator):
    SUPPORTED_INPUT_NAMES = (
        "input_ids",
        "attention_mask",
        "token_type_ids",
        "capitalization_ids",
    )
```
```python
class CustomOnnxConfig(TextEncoderOnnxConfig):
    # Specifies how to normalize the BertConfig, this is needed to access common attributes
    # during dummy input generation.
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig

    DUMMY_INPUT_GENERATOR_CLASSES = (MyDummyTextInputGenerator,)

    # Sets the absolute tolerance used when validating the exported ONNX model against the
    # reference model.
    ATOL_FOR_VALIDATION = 1e-4

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        if self.task == "multiple-choice":
            dynamic_axis = {0: "batch_size", 1: "num_choices", 2: "sequence_length"}
        else:
            dynamic_axis = {0: "batch_size", 1: "sequence_length"}
        return {
            "input_ids": dynamic_axis,
            "attention_mask": dynamic_axis,
            "capitalization_ids": dynamic_axis,
            "token_type_ids": dynamic_axis,
        }
```
Since an OnnxConfig already exists for bert, the register method will not do anything. If you want to be able to overwrite existing registered configurations, you can do that:
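For example, something like this (assuming your optimum version's TasksManager.create_register supports the overwrite_existing flag):

```python
# Assumption: create_register accepts overwrite_existing in your optimum version.
register_for_onnx = TasksManager.create_register("onnx", overwrite_existing=True)

@register_for_onnx("bert", "token-classification")
class CustomOnnxConfig(TextEncoderOnnxConfig):
    ...
```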
Hi Maiia,
I saw your message here. Thanks for creating a notebook, it really helped me to try things out.
You do not need to register anything, since you are not using the TasksManager in the end. Creating your OnnxConfig and using the export function is enough.
But because you are not using the TasksManager, you have to instantiate the OnnxConfig manually, and you need to specify the name of the task when doing so; otherwise it will infer the task to be the default one (BertModel and not BertForTokenClassification).
So to be able to export your model, the fix is actually easy:
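Something along these lines (a sketch; the exact export signature can vary between optimum versions, so I am using keyword arguments):

```python
from optimum.exporters.onnx import export

# Pass the task explicitly so the config does not fall back to the default task.
onnx_config = CustomOnnxConfig(base_model.config, task="token-classification")

onnx_inputs, onnx_outputs = export(
    model=base_model,
    config=onnx_config,
    output=onnx_path,
    opset=onnx_config.DEFAULT_ONNX_OPSET,
)
```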
Last question: if I need to reload it (as ORTModelForTokenClassification, to optimize it with ORTOptimizer), do I need to save a modified config? If I save the original config and do

```python
import torch
from optimum.onnxruntime import ORTModelForTokenClassification

reloaded_model = ORTModelForTokenClassification.from_pretrained("onnx")
inputs = {
    k: torch.zeros([2, 16], dtype=torch.long)
    for k in ["input_ids", "attention_mask", "capitalization_ids", "token_type_ids"]
}
reloaded_model(**inputs)
```

I get an error: ValueError: Model requires 4 inputs. Input Feed contains 3
Or maybe I just need to modify the ORTModelForTokenClassification class and that will fix the issue…
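For example, a rough sketch of what I have in mind (untested; it assumes the ORTModel wrapper exposes the underlying onnxruntime InferenceSession as self.model, as in current optimum versions):

```python
import torch
from optimum.onnxruntime import ORTModelForTokenClassification
from transformers.modeling_outputs import TokenClassifierOutput

# Hypothetical subclass that forwards the extra input to the ONNX session.
class CustomORTModelForTokenClassification(ORTModelForTokenClassification):
    def forward(self, input_ids, attention_mask, token_type_ids=None, capitalization_ids=None, **kwargs):
        # Feed all four inputs to the session instead of the default three.
        onnx_inputs = {
            "input_ids": input_ids.cpu().numpy(),
            "attention_mask": attention_mask.cpu().numpy(),
            "token_type_ids": token_type_ids.cpu().numpy(),
            "capitalization_ids": capitalization_ids.cpu().numpy(),
        }
        # self.model is assumed to be the onnxruntime.InferenceSession held by ORTModel.
        logits = self.model.run(None, onnx_inputs)[0]
        return TokenClassifierOutput(logits=torch.from_numpy(logits))
```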