ModernBert unable to convert to onnx model

I am trying to convert ModernBERT from PyTorch (pt) to ONNX, but I get a compilation error from Triton:

Traceback (most recent call last):
  File "/usr/local/bin/optimum-cli", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.10/dist-packages/optimum/commands/optimum_cli.py", line 208, in main
    service.run()
  File "/usr/local/lib/python3.10/dist-packages/optimum/commands/export/onnx.py", line 265, in run
    main_export(
  File "/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/__main__.py", line 375, in main_export
    onnx_export_from_model(
  File "/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/convert.py", line 1176, in onnx_export_from_model
    _, onnx_outputs = export_models(
  File "/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/convert.py", line 762, in export_models
    export(
  File "/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/convert.py", line 867, in export
    export_output = export_pytorch(
  File "/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/convert.py", line 563, in export_pytorch
    onnx_export(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/__init__.py", line 375, in export
    export(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 502, in export
    _export(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 1564, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 1113, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 997, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/utils.py", line 904, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 1500, in _get_trace_graph
    outs = ONNXTracedModule(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 139, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/usr/local/lib/python3.10/dist-packages/torch/jit/_trace.py", line 130, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/optimum/exporters/onnx/model_patcher.py", line 151, in patched_forward
    outputs = self.orig_forward(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/modernbert/modeling_modernbert.py", line 1160, in forward
    outputs = self.model(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/modernbert/modeling_modernbert.py", line 913, in forward
    layer_outputs = encoder_layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/modernbert/modeling_modernbert.py", line 529, in forward
    attn_outputs = self.attn(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/modernbert/modeling_modernbert.py", line 487, in forward
    attn_outputs = MODERNBERT_ATTENTION_FUNCTION[self.config._attn_implementation](
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/modernbert/modeling_modernbert.py", line 349, in flash_attention_forward
    qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1726, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/modernbert/modeling_modernbert.py", line 178, in forward
    qkv = apply_rotary_unpadded(
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/modernbert/modeling_modernbert.py", line 136, in apply_rotary_unpadded
    return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 575, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.10/dist-packages/transformers/models/modernbert/modeling_modernbert.py", line 75, in forward
    apply_rotary(
  File "/usr/local/lib/python3.10/dist-packages/flash_attn/ops/triton/rotary.py", line 202, in apply_rotary
    rotary_kernel[grid](
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 345, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/triton/runtime/jit.py", line 662, in run
    kernel = self.compile(
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 276, in compile
    module = src.make_ir(options, codegen_fns, context)
  File "/usr/local/lib/python3.10/dist-packages/triton/compiler/compiler.py", line 113, in make_ir
    return ast_to_ttir(self.fn, self, context=context, options=options, codegen_fns=codegen_fns)
triton.compiler.errors.CompilationError: at 32:22:
    # Meta-parameters
    BLOCK_K: tl.constexpr,
    IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
    IS_VARLEN: tl.constexpr,
    INTERLEAVED: tl.constexpr,
    CONJUGATE: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_batch = tl.program_id(axis=1)
    pid_head = tl.program_id(axis=2)
    rotary_dim_half = rotary_dim // 2
                      ^
IncompatibleTypeErrorImpl('invalid operands of type pointer<int64> and triton.language.int32')

The command is:

optimum-cli export onnx -m "./"  --task text-classification --device cuda model.onnx

I use an unchanged local copy of the ModernBERT model from Hugging Face.

optimum-cli was installed from source with pip; the branch is

other env:
python==3.10.12
torch==2.5.1
triton==3.1.0
flash_attn-2.7.2.post1+cu12torch2.5cxx11abiFALSE-cp310-cp310-linux_x86_64.whl

1 Like

How about this answer?

1 Like

Hello, I ran into the same error. Any progress? :slight_smile:

1 Like

Thanks for sharing. It helps me a lot.