Exporting imported BERT model to ONNX

I have the following model:

class BertClassifier(nn.Module):
    """
    Class defining the classifier model with a BERT encoder and a single fully connected classifier layer.
    """
    def __init__(self, dropout=0.5, num_labels=24):
        super(BertClassifier, self).__init__()

        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(768, num_labels)
        self.relu = nn.ReLU()
        self.best_score = 0

    def forward(self, input_id, mask):
        _, pooled_output = self.bert(input_ids=input_id, attention_mask=mask, return_dict=False)
        output = self.relu(self.linear(self.dropout(pooled_output)))

        return output

Using these inputs:

    ex_string = "example string"
    inputs = tokenizer(ex_string,
                       padding='max_length', max_length=512, truncation=True,
                       return_tensors="pt")
    input_id = inputs['input_ids'].squeeze(1)
    mask = inputs['attention_mask']

And I export the model to ONNX using:

        with torch.no_grad():
            input_names, output_names, dynamic_axes = infer_shapes(model, input_id, mask)
            torch.onnx.export(model=model,
                              args=(input_id, mask),
                              f='tryout.onnx',
                              input_names=input_names,
                              output_names=output_names,
                              dynamic_axes=dynamic_axes,
                              export_params=True,
                              do_constant_folding=False,
                              verbose=False)

Which results in the following stack trace:

/.local/lib/python3.9/site-packages/torch/onnx/utils.py:1294: UserWarning: Provided key input_ids for dynamic axes is not a valid input/output name
  warnings.warn("Provided key {} for dynamic axes is not a valid input/output name".format(key))
/.local/lib/python3.9/site-packages/torch/onnx/symbolic_helper.py:325: UserWarning: Type cannot be inferred, which might cause exported graph to produce incorrect results.
  warnings.warn("Type cannot be inferred, which might cause exported graph to produce incorrect results.")
[W shape_type_inference.cpp:434] Warning: Constant folding in symbolic shape inference fails: index_select(): Index is supposed to be a vector
Exception raised from index_select_out_cpu_ at ../aten/src/ATen/native/TensorAdvancedIndexing.cpp:887 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x42 (0x7fd75aaa7d62 in /.local/lib/python3.9/site-packages/torch/lib/libc10.so)
frame #1: at::native::index_select_out_cpu_(at::Tensor const&, long, at::Tensor const&, at::Tensor&) + 0x3a9 (0x7fd79f9d4189 in /.local/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #2: at::native::index_select_cpu_(at::Tensor const&, long, at::Tensor const&) + 0xe6 (0x7fd79f9d6146 in /.local/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #3: <unknown function> + 0x1d37f12 (0x7fd7a00cdf12 in /.local/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #4: at::_ops::index_select::redispatch(c10::DispatchKeySet, at::Tensor const&, long, at::Tensor const&) + 0xb9 (0x7fd79fc69099 in /home/floris/.local/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #5: <unknown function> + 0x3250ac3 (0x7fd7a15e6ac3 in /.local/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #6: <unknown function> + 0x32510f5 (0x7fd7a15e70f5 in /.local/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #7: at::_ops::index_select::call(at::Tensor const&, long, at::Tensor const&) + 0x166 (0x7fd79fce8ce6 in /.local/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #8: torch::jit::onnx_constant_fold::runTorchBackendForOnnx(torch::jit::Node const*, std::vector<at::Tensor, std::allocator<at::Tensor> >&, int) + 0x1b5f (0x7fd821d4c6ff in /.local/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #9: <unknown function> + 0xbbdc22 (0x7fd821d93c22 in /.local/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #10: torch::jit::ONNXShapeTypeInference(torch::jit::Node*, std::map<std::string, c10::IValue, std::less<std::string>, std::allocator<std::pair<std::string const, c10::IValue> > > const&, int) + 0xa8e (0x7fd821d9946e in /.local/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #11: <unknown function> + 0xbc4f74 (0x7fd821d9af74 in /.local/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #12: <unknown function> + 0xb35730 (0x7fd821d0b730 in /.local/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #13: <unknown function> + 0x2a5d8b (0x7fd82147bd8b in /.local/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #14: python3() [0x53a8eb]
<omitting python frames>
frame #17: python3() [0x50f5e9]
frame #20: python3() [0x50f5e9]
frame #23: python3() [0x50f5e9]
frame #26: python3() [0x50f5e9]
frame #29: python3() [0x50f5e9]
frame #32: python3() [0x50f5e9]
frame #35: python3() [0x608ebb]
frame #36: python3() [0x603ea4]
frame #37: python3() [0x60834d]
frame #41: <unknown function> + 0x2dfd0 (0x7fd8aa12cfd0 in /lib/x86_64-linux-gnu/libc.so.6)
frame #42: __libc_start_main + 0x7d (0x7fd8aa12d07d in /lib/x86_64-linux-gnu/libc.so.6)
 (function ComputeConstantFolding)
Traceback (most recent call last):
  File "/bert_extraction/bert_onnx.py", line 89, in <module>
    torch.onnx.export(model=model,
  File "/.local/lib/python3.9/site-packages/torch/onnx/__init__.py", line 316, in export
    return utils.export(model, args, f, export_params, verbose, training,
  File "/.local/lib/python3.9/site-packages/torch/onnx/utils.py", line 107, in export
    _export(model, args, f, export_params, verbose, training, input_names, output_names,
  File "/.local/lib/python3.9/site-packages/torch/onnx/utils.py", line 724, in _export
    _model_to_graph(model, args, verbose, input_names,
  File "/.local/lib/python3.9/site-packages/torch/onnx/utils.py", line 544, in _model_to_graph
    params_dict = torch._C._jit_pass_onnx_constant_fold(graph, params_dict,
IndexError: index_select(): Index is supposed to be a vector

I can get the ONNX model to compile when I change the do_constant_folding flag to False, but obviously I don’t want to do that, as I’m trying to optimize the inference-time. Additionally, this results in the warning: ::FunctionImpl::FunctionImpl(onnxruntime::Graph&, const NodeIndex&, const onnx::FunctionProto&, const std::unordered_map<std::basic_string<char>, const onnx::FunctionProto*>&, std::vector<std::unique_ptr<onnxruntime::Function> >&, const onnxruntime::logging::Logger&, bool) status.IsOK() was false. Resolve subgraph failed:Node (0xad87190) Op (Flatten) [ShapeInferenceError] Invalid value(-1) for attribute 'axis' . Execution will fail if ORT does not have a specialized kernel for this op

Can anyone shed some light on the error or what I’m doing wrong? I know HuggingFace has a way to export BERT to onnx, so there must be some way to work around this from an imported BERT model, right?

1 Like