Cannot export to ONNX with optimum.onnxruntime

Hello. I'm trying to export the Upstage/SOLAR-10.7B-v1.0 model to ONNX with optimum.onnxruntime, following the Optimum guide "Export a model to ONNX with optimum.exporters.onnx".
But I ran into this error:

RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 32 but got size 8 for tensor number 1 in the list.

Could someone please advise me on how to resolve this issue?

During the export I also saw the following warning multiple times. Could it be related to the failure?

/home/sh/miniconda3/lib/python3.11/site-packages/transformers/modeling_attn_mask_utils.py:94: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can’t record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:

Here is the code I used:

from optimum.onnxruntime import ORTModelForCausalLM
from transformers import AutoTokenizer

model_checkpoint = "Upstage/SOLAR-10.7B-v1.0"
save_directory = "onnx/"

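# export=True runs the PyTorch-to-ONNX conversion when loading the checkpoint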
ort_model = ORTModelForCausalLM.from_pretrained(model_checkpoint, export=True)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

ort_model.save_pretrained(save_directory)
tokenizer.save_pretrained(save_directory)
And here is the full console output:

$ python3 convert_to_onnx.py
Framework not specified. Using pt to export to ONNX.
Loading checkpoint shards: 100%|███████████████████████████████████████████| 5/5 [00:05<00:00,  1.11s/it]
Using the export variant default. Available variants are:
        - default: The default ONNX variant.
use_past = False is different than use_present_in_outputs = True, the value of use_present_in_outputs value will be used for the outputs.
Using framework PyTorch: 2.0.1+cu118
Overriding 1 configuration item(s)
        - use_cache -> True
/home/sh/miniconda3/lib/python3.11/site-packages/transformers/modeling_attn_mask_utils.py:94: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal:
/home/sh/miniconda3/lib/python3.11/site-packages/transformers/modeling_attn_mask_utils.py:137: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if past_key_values_length > 0:
/home/sh/miniconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:140: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if seq_len > self.max_seq_len_cached:
/home/sh/miniconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:392: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
/home/sh/miniconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:399: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
/home/sh/miniconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py:409: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
============= Diagnostic Run torch.onnx.export version 2.0.1+cu118 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

Saving external data to one file...
Using framework PyTorch: 2.0.1+cu118
Overriding 1 configuration item(s)
        - use_cache -> True
Asked a sequence length of 16, but a sequence length of 1 will be used with use_past == True for `input_ids`.
============= Diagnostic Run torch.onnx.export version 2.0.1+cu118 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

Traceback (most recent call last):
  File "/home/sh/tensorrt-test/convert_to_onnx.py", line 8, in <module>
    ort_model = ORTModelForCausalLM.from_pretrained(model_checkpoint, export=True)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/optimum/onnxruntime/modeling_ort.py", line 647, in from_pretrained
    return super().from_pretrained(
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/optimum/modeling_base.py", line 372, in from_pretrained
    return from_pretrained_method(
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/optimum/onnxruntime/modeling_decoder.py", line 567, in _from_transformers
    main_export(
  File "/home/sh/miniconda3/lib/python3.11/site-packages/optimum/exporters/onnx/__main__.py", line 486, in main_export
    _, onnx_outputs = export_models(
                      ^^^^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 752, in export_models
    export(
  File "/home/sh/miniconda3/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 855, in export
    export_output = export_pytorch(
                    ^^^^^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/optimum/exporters/onnx/convert.py", line 572, in export_pytorch
    onnx_export(
  File "/home/sh/miniconda3/lib/python3.11/site-packages/torch/onnx/utils.py", line 506, in export
    _export(
  File "/home/sh/miniconda3/lib/python3.11/site-packages/torch/onnx/utils.py", line 1548, in _export
    graph, params_dict, torch_out = _model_to_graph(
                                    ^^^^^^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/torch/onnx/utils.py", line 1113, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/torch/onnx/utils.py", line 989, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/torch/onnx/utils.py", line 893, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/torch/jit/_trace.py", line 1268, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
                ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/optimum/exporters/onnx/model_patcher.py", line 113, in patched_forward
    outputs = self.orig_forward(*args, **kwargs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1034, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 922, in forward
    layer_outputs = decoder_layer(
                    ^^^^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 672, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
                                                          ^^^^^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1488, in _slow_forward
    result = self.forward(*input, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/sh/miniconda3/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 382, in forward
    key_states = torch.cat([past_key_value[0], key_states], dim=2)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 32 but got size 8 for tensor number 1 in the list.
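
One detail that may be relevant: the 32 and 8 in the error message match the model's attention-head layout. SOLAR-10.7B uses grouped-query attention, so (assuming the config on the Hub is what I think it is) num_attention_heads and num_key_value_heads differ:

from transformers import AutoConfig

# Quick check of the attention layout; the expected values in the comments
# are my assumption based on the checkpoint's config on the Hub.
config = AutoConfig.from_pretrained("Upstage/SOLAR-10.7B-v1.0")
print(config.num_attention_heads)   # expecting 32 (query heads)
print(config.num_key_value_heads)   # expecting 8 (key/value heads, GQA)

So my guess is that the dummy past_key_values used during tracing are built with 32 heads, while the model's actual key/value states only have 8, so the torch.cat in modeling_llama.py fails. I'm not sure whether that points at the exporter or at something wrong in my setup.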