How does the ONNX exporter work for GenerationModel with `past_key_value`?

Hi @fxmarty

To use Optimum, I need to export my decoder-based generation model to ONNX. However, the forward function takes a mems argument that is only available from the second pass onward; it's similar to past_key_value. You can find the code here.

Optimum seems to have pretty good support for various decoder models. I wonder how the conversion worked. And what should I do to convert my customized decoder model?

As far as I know, there is a way to export an encoder-decoder model such as BART. But the decoder is traced twice, so in the ONNX graph there will be two decoders, which seems to waste GPU memory: transformers/generation_onnx.py at 197e7ce911d91d85eb2f91858720957c2d979cd2 · huggingface/transformers (github.com)

Hi @in-certo. First off, it is unfortunate that GLM is not supported natively in transformers.

About the duplication of memory, you are very right. There is some ongoing work to avoid it: https://github.com/huggingface/optimum/pull/647

The idea is to insert an If node in the ONNX graph to dispatch between two branches, depending on whether it is the first pass in the decoder or not (in the latter case, the past key values are reused).
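To make the dispatch concrete, here is a toy ONNX graph built by hand with onnx.helper. It is purely an illustration, not Optimum's actual merging code: the use_cache_branch input name and the two trivial branch bodies are placeholders standing in for the without-past / with-past decoder subgraphs.

import numpy as np
import onnx
import onnxruntime as ort
from onnx import TensorProto, helper

# Declared outputs of the two subgraphs; both must have compatible types.
then_out = helper.make_tensor_value_info("then_out", TensorProto.FLOAT, [None])
else_out = helper.make_tensor_value_info("else_out", TensorProto.FLOAT, [None])

# "then" branch (stand-in for the decoder pass that reuses past key values): y = x + 1
then_branch = helper.make_graph(
    [helper.make_node("Add", ["x", "one"], ["then_out"])],
    "then_branch",
    [],
    [then_out],
    initializer=[helper.make_tensor("one", TensorProto.FLOAT, [1], [1.0])],
)

# "else" branch (stand-in for the first decoder pass, without past): y = x * 2
else_branch = helper.make_graph(
    [helper.make_node("Mul", ["x", "two"], ["else_out"])],
    "else_branch",
    [],
    [else_out],
    initializer=[helper.make_tensor("two", TensorProto.FLOAT, [1], [2.0])],
)

# The If node picks a branch at runtime based on the boolean input.
if_node = helper.make_node(
    "If", ["use_cache_branch"], ["y"], then_branch=then_branch, else_branch=else_branch
)

graph = helper.make_graph(
    [if_node],
    "merged_decoder_toy",
    [
        helper.make_tensor_value_info("x", TensorProto.FLOAT, [None]),
        helper.make_tensor_value_info("use_cache_branch", TensorProto.BOOL, []),
    ],
    [helper.make_tensor_value_info("y", TensorProto.FLOAT, [None])],
)
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 16)])
onnx.checker.check_model(model)

sess = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
x = np.array([3.0], dtype=np.float32)
print(sess.run(None, {"x": x, "use_cache_branch": np.array(True)}))   # [array([4.], dtype=float32)]
print(sess.run(None, {"x": x, "use_cache_branch": np.array(False)}))  # [array([6.], dtype=float32)]

A single exported graph can therefore serve both the first pass and the subsequent passes, instead of keeping two full copies of the decoder weights.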

I wonder how the conversion worked. And what should I do to convert my customized decoder model?

It is a fair question. I believe you could fork Optimum to handle the export for your model, though it could require some work if the architecture of the GLM library is quite far from transformers.

I also believe we should make it easier for people to extend Optimum to handle the export of their custom models; I'll keep it in mind.

Some good references:
Adding support for an unsupported architecture
https://github.com/huggingface/optimum/blob/main/optimum/exporters/onnx/base.py
https://github.com/huggingface/optimum/blob/main/optimum/exporters/onnx/config.py
https://github.com/huggingface/optimum/blob/main/optimum/exporters/onnx/model_configs.py (refer for example to GPT2OnnxConfig)
The script handling the actual export: https://github.com/huggingface/optimum/blob/main/optimum/exporters/onnx/__main__.py and its args
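As a very rough sketch of what the GPT2OnnxConfig reference above looks like when adapted to a custom decoder-only model: the import paths, base class, and attribute names below are recalled from memory and may differ across Optimum versions, so treat this as a starting point to check against model_configs.py rather than working code.

# Rough, unverified sketch of a custom ONNX export config for a decoder-only model.
# The import paths and class names are assumptions; check optimum/exporters/onnx/config.py
# and model_configs.py for the exact API of your Optimum version.
from optimum.exporters.onnx.config import TextDecoderOnnxConfig  # assumed location
from optimum.utils import NormalizedTextConfig                   # assumed location


class MyGLMOnnxConfig(TextDecoderOnnxConfig):
    DEFAULT_ONNX_OPSET = 13
    # Map the attribute names Optimum expects onto the names used by the model's config
    # (the GPT-2-style "n_layer"/"n_head" names here are placeholders).
    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig.with_args(
        num_layers="n_layer",
        num_attention_heads="n_head",
    )

You would likely also need to register such a config with the exporter's task machinery so the export script can find it for your architecture; the __main__.py reference above shows where the lookup happens.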


I am so very grateful for your time. The references are of great help. I am looking forward to the merged decoder feature :)

Hi @fxmarty ,

Can we use a torch.empty((a, b, c)) as a dummy input? If that is possible, there is no need to create two separate graphs. Am I missing anything here?

Sorry I missed your messages!

The thing is that you need to enter some control flow only if a past key value has already been computed; if it has not, you must not enter it.

And doing something like:

import torch

a = torch.empty(5, 10)   # uninitialized, but with a real (non-zero) size
b = torch.rand(5, 1)
res = torch.cat([a, b], dim=1)

unfortunately does not help.


Thank you for the time you’ve taken to support me as always.

The implication seems to be “We cannot modify the source code of the model implementation, so we have to go into different branches.”

But if we remove the if-condition and pass in an empty tensor, won't the graph be exactly the same?

Oh I see what you mean - using something like torch.empty(1, 0, 4, 16)?

I haven’t tested, but it could be that a single graph is enough then:

import torch

a = torch.empty(1, 0, 10, 14)   # zero-length along dim=1
b = torch.rand(1, 12, 10, 14)
res = torch.cat([a, b], dim=1)  # shape (1, 12, 10, 14), i.e. just b

Not sure if torch.onnx.export would treat that right, but it could work!

But indeed - in Optimum we inherit from the code of transformers, which in some cases cannot be modified due to legacy constraints, unfortunately.

By the way, the PR merging the decoders without/with past has been merged: https://github.com/huggingface/optimum/pull/647
There are a few remaining things to do: Merged ONNX decoder next steps · Issue #784 · huggingface/optimum · GitHub
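If you want to try it end to end, the high-level export helper should cover this. A minimal sketch, assuming a recent Optimum version; gpt2 is just a stand-in model id, and depending on the version the merged decoder may only be produced as a post-processing step:

# Minimal sketch of exporting a decoder-only model with Optimum's high-level helper.
# Depending on the Optimum version, the decoder without/with past may be merged into a
# single ONNX file (with an If node) as a post-processing step.
from optimum.exporters.onnx import main_export

main_export(
    "gpt2",              # stand-in model id; replace with your own model
    output="gpt2_onnx",  # output directory for the exported ONNX file(s)
)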


ONNX seems to handle this correctly!

import torch
import onnxruntime as ort


# A toy model that appends a fixed block of zeros along the sequence dimension.
class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor):
        y = torch.zeros((1, 2, 4)).to(x.device)
        return torch.cat((x, y), dim=1)

model = Model()
scripted = torch.jit.script(model)  # not used below; torch.onnx.export traces `model` directly

real_x = torch.ones((1, 1, 4)).to("cuda")   # a "real" input with one timestep
fake_x = torch.empty((1, 0, 4)).to("cuda")  # a dummy input with zero timesteps

print("Output of model with a real input:")
print(model(real_x))
print(model(real_x).shape)
print("-" * 20)

print("Output of model with a fake input:")
print(model(fake_x))
print(model(fake_x).shape)
print("-" * 20)

model_path = "simple_model.onnx"

# Export with a dynamic sequence axis so that both the real and the zero-length input are valid.
torch.onnx.export(
    model,
    real_x,
    model_path,
    input_names=["x"],
    output_names=["y"],
    opset_version=16,
    dynamic_axes={
        "x": {1: "sequence"},
        "y": {1: "sequence"},
    }
)

session = ort.InferenceSession(model_path, providers=["CUDAExecutionProvider"])


print("Output of ONNX with a real input:")
real_y = session.run(None, {"x": real_x.cpu().numpy()})[0]
print(real_y)
print(real_y.shape)
print("-" * 20)

print("Output of ONNX with a fake input:")
fake_y = session.run(None, {"x": fake_x.cpu().numpy()})[0]
print(fake_y)
print(fake_y.shape)
print("-" * 20)

Output of model with a real input:
tensor([[[1., 1., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]], device='cuda:0')
torch.Size([1, 3, 4])
--------------------
Output of model with a fake input:
tensor([[[0., 0., 0., 0.],
         [0., 0., 0., 0.]]], device='cuda:0')
torch.Size([1, 2, 4])
--------------------
Output of ONNX with a real input:
[[[1. 1. 1. 1.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]]
(1, 3, 4)
--------------------
Output of ONNX with a fake input:
[[[0. 0. 0. 0.]
  [0. 0. 0. 0.]]]
(1, 2, 4)
--------------------

That's very clean! So if you are going to reimplement the models, it is probably a good idea to go this way to simplify your life and remove unnecessary control flow.
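For instance, here is a minimal sketch of that branchless pattern (purely illustrative, not taken from any particular model): the first pass simply uses a zero-length past instead of None, so no if is needed.

import torch

# Illustrative only: a layer that always receives a "past" tensor and always
# concatenates it, so the first pass passes a zero-length past instead of None.
class BranchlessDecoderLayer(torch.nn.Module):
    def __init__(self, hidden: int = 16):
        super().__init__()
        self.proj = torch.nn.Linear(hidden, hidden)

    def forward(self, x: torch.Tensor, past: torch.Tensor):
        states = torch.cat([past, x], dim=1)  # also works when past has 0 timesteps
        return self.proj(states), states      # return the updated "past" for the next step


layer = BranchlessDecoderLayer()
batch, hidden = 1, 16

# First pass: a zero-length past instead of None, no branching needed.
past = torch.zeros(batch, 0, hidden)
out, past = layer(torch.rand(batch, 3, hidden), past)   # past now holds 3 timesteps

# Second pass: a single new token, the cached states are reused.
out, past = layer(torch.rand(batch, 1, hidden), past)   # past now holds 4 timesteps
print(past.shape)  # torch.Size([1, 4, 16])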

Hi @fxmarty

It does work for a small model!

But when exporting a larger one, I got a CUDA OOM error. Could you provide some insights? Your suggestions have always been helpful!

I opened a new thread here: CUDA OOM when export a large model to ONNX - Optimum - Hugging Face Forums