ONNX seems to handle this correctly!
import torch
import onnxruntime as ort
class Model(torch.nn.Module):
    """Parameter-free module that appends two all-zero rows along dim 1."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Concatenate ``x`` with a ``(1, 2, 4)`` zero tensor along dim 1.

        Args:
            x: Input tensor. It must be concatenable with a ``(1, 2, 4)``
                tensor along dim 1, i.e. shaped ``(1, N, 4)``; ``N`` may
                be 0.

        Returns:
            A tensor of shape ``(1, N + 2, 4)`` on the same device as ``x``:
            the rows of ``x`` followed by two rows of zeros.
        """
        # Allocate the padding directly on x's device rather than building it
        # on the CPU and copying it across with .to().
        y = torch.zeros((1, 2, 4), device=x.device)
        return torch.cat((x, y), dim=1)
def _report(title, compute):
    """Print ``title``, run ``compute()``, then print its result and shape.

    The header is printed before ``compute`` is called so that, if the
    computation raises, the output still shows which step was running.

    Args:
        title: Header line printed first.
        compute: Zero-argument callable returning an object with ``.shape``.

    Returns:
        Whatever ``compute()`` returned.
    """
    print(title)
    value = compute()
    print(value)
    print(value.shape)
    print("-" * 20)
    return value


model = Model()
# NOTE(review): `scripted` is not referenced anywhere in the visible script;
# kept so the TorchScript compilation step still runs.
scripted = torch.jit.script(model)

# Create the inputs directly on the GPU instead of allocating on the CPU and
# copying them over.
real_x = torch.ones((1, 1, 4), device="cuda")
# The "fake" input has length 0 along dim 1 (the axis exported as dynamic).
fake_x = torch.empty((1, 0, 4), device="cuda")

# Eager PyTorch: each input goes through the model exactly once.
_report("Output of model with a real input:", lambda: model(real_x))
_report("Output of model with a fake input:", lambda: model(fake_x))

model_path = "simple_model.onnx"
torch.onnx.export(
    model,
    real_x,
    model_path,
    input_names=["x"],
    output_names=["y"],
    opset_version=16,
    # Dim 1 of both the input and the output is exported as the dynamic axis
    # "sequence", so the session accepts inputs of any dim-1 length.
    dynamic_axes={
        "x": {1: "sequence"},
        "y": {1: "sequence"},
    },
)
session = ort.InferenceSession(model_path, providers=["CUDAExecutionProvider"])

# ONNX Runtime takes NumPy inputs, so the tensors are moved to host memory.
real_y = _report(
    "Output of ONNX with a real input:",
    lambda: session.run(None, {"x": real_x.cpu().numpy()})[0],
)
fake_y = _report(
    "Output of ONNX with a fake input:",
    lambda: session.run(None, {"x": fake_x.cpu().numpy()})[0],
)
Output of model with a real input:
tensor([[[1., 1., 1., 1.],
         [0., 0., 0., 0.],
         [0., 0., 0., 0.]]], device='cuda:0')
torch.Size([1, 3, 4])
--------------------
Output of model with a fake input:
tensor([[[0., 0., 0., 0.],
         [0., 0., 0., 0.]]], device='cuda:0')
torch.Size([1, 2, 4])
--------------------
Output of ONNX with a real input:
[[[1. 1. 1. 1.]
  [0. 0. 0. 0.]
  [0. 0. 0. 0.]]]
(1, 3, 4)
--------------------
Output of ONNX with a fake input:
[[[0. 0. 0. 0.]
  [0. 0. 0. 0.]]]
(1, 2, 4)
--------------------