How does the ONNX exporter work for GenerationModel with `past_key_value`?

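ONNX seems to handle this correctly! The exported model accepts an input that is empty (length 0) along the dynamic sequence axis, just as an empty `past_key_value` would be on the first generation step, and it returns the same result as PyTorch: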

import torch
import onnxruntime as ort


class Model(torch.nn.Module):
    """Toy stand-in for a KV-cache style concatenation.

    ``forward`` appends ``pad_len`` zero rows to its input along dim 1, the
    way a new key/value block is concatenated onto ``past_key_value``. The
    input may have length 0 along dim 1 (an "empty past").

    Args:
        pad_len: Number of zero rows appended along dim 1 (default 2).
    """

    def __init__(self, pad_len: int = 2) -> None:
        super().__init__()
        self.pad_len = pad_len

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``torch.cat((x, zeros), dim=1)``.

        Args:
            x: 3-D tensor of shape ``(B, S, F)``; ``S`` may be 0.

        Returns:
            Tensor of shape ``(B, S + pad_len, F)`` with ``x``'s dtype and
            device; the last ``pad_len`` rows along dim 1 are zeros.
        """
        # Allocate directly on x's device/dtype: avoids a CPU allocation plus
        # host->device copy, and avoids torch.cat promoting e.g. fp16 -> fp32.
        # Batch/feature sizes come from x so inputs other than (1, S, 4) work.
        y = torch.zeros(
            (x.size(0), self.pad_len, x.size(-1)),
            dtype=x.dtype,
            device=x.device,
        )
        return torch.cat((x, y), dim=1)

# Run on the GPU when one is available, otherwise fall back to the CPU so the
# demo does not crash on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model()
# Not used further; scripting only confirms forward() is TorchScript-compatible.
scripted = torch.jit.script(model)

# "real" input has one step along dim 1; "fake" input is empty along dim 1,
# like an empty past_key_value on the first generation step.
real_x = torch.ones((1, 1, 4), device=device)
fake_x = torch.empty((1, 0, 4), device=device)

# Run each forward pass once and reuse the result for printing and checking.
real_ref = model(real_x)
fake_ref = model(fake_x)

print("Output of model with a real input:")
print(real_ref)
print(real_ref.shape)
print("-" * 20)

print("Output of model with a fake input:")
print(fake_ref)
print(fake_ref.shape)
print("-" * 20)

model_path = "simple_model.onnx"

torch.onnx.export(
    model,
    real_x,
    model_path,
    input_names=["x"],
    output_names=["y"],
    opset_version=16,
    # Mark dim 1 as symbolic so the exported graph accepts any length,
    # including 0, rather than baking in the traced length of 1.
    dynamic_axes={
        "x": {1: "sequence"},
        "y": {1: "sequence"},
    },
)

# Prefer CUDA when torch is using it, but always keep the CPU provider as a
# fallback so session creation works without a GPU build of onnxruntime.
if device.type == "cuda":
    providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
else:
    providers = ["CPUExecutionProvider"]
session = ort.InferenceSession(model_path, providers=providers)


print("Output of ONNX with a real input:")
real_y = session.run(None, {"x": real_x.cpu().numpy()})[0]
print(real_y)
print(real_y.shape)
print("-" * 20)

print("Output of ONNX with a fake input:")
fake_y = session.run(None, {"x": fake_x.cpu().numpy()})[0]
print(fake_y)
print(fake_y.shape)
print("-" * 20)

# Actually verify the claim that ONNX matches PyTorch (values and shapes),
# instead of relying on eyeballing the printed tensors.
torch.testing.assert_close(torch.from_numpy(real_y), real_ref.detach().cpu())
torch.testing.assert_close(torch.from_numpy(fake_y), fake_ref.detach().cpu())
Output of model with a real input:
tensor([[[1., 1., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]], device='cuda:0')
torch.Size([1, 3, 4])
--------------------
Output of model with a fake input:
tensor([[[0., 0., 0., 0.],
         [0., 0., 0., 0.]]], device='cuda:0')
torch.Size([1, 2, 4])
--------------------
Output of ONNX with a real input:
[[[1. 1. 1. 1.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]]
(1, 3, 4)
--------------------
Output of ONNX with a fake input:
[[[0. 0. 0. 0.]
  [0. 0. 0. 0.]]]
(1, 2, 4)
--------------------
2 Likes