I’m currently trying to convert a PyTorch model to ONNX using the torch.onnx.export
function, but I’m running into some issues. Specifically, I’m getting an error message when trying to convert my model using the following code:
def converting_onnx(onnx_path):
# create an instance of AlignTTSArgs with desired values
args = AlignTTSArgs(num_chars=105, out_channels=80, hidden_channels=256, hidden_channels_dp=256,
encoder_type='fftransformer',
encoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
# Added "use_causal_conv": True for Causal Convolution
decoder_type='fftransformer',
decoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
length_scale=1.0, num_speakers=0, use_speaker_embedding=False, use_d_vector_file=False, d_vector_dim=0)
# create an instance of AlignTTS
model = AlignTTS(config=args)
# Load the saved checkpoint
checkpoint = torch.load(onnx_path, map_location=torch.device('cpu'))
# Load the state dict into the model
model.load_state_dict(checkpoint['model'])
print(model)
# Set the model to evaluation mode
model.eval()
batch_size = 8
in_dim = 105
seq_len = 256
# Create dummy inputs
dummy_input_x1 = torch.zeros(batch_size, seq_len).to(dtype=torch.long)
dummy_input_x2 = torch.zeros(batch_size, seq_len).to(dtype=torch.long)
dummy_input_x3 = torch.zeros(batch_size, seq_len).to(dtype=torch.long)
y_lengths = torch.zeros(batch_size).to(dtype=torch.long)
print('Dummy input X1:', dummy_input_x1)
print('Dummy input X2:', dummy_input_x2)
print('Dummy input X3:', dummy_input_x3)
print('Y lengths:', y_lengths)
# Increase the recursion limit
sys.setrecursionlimit(5000)
# Trace the model
traced_model = torch.jit.trace(model, (dummy_input_x1, dummy_input_x2, dummy_input_x3, y_lengths), strict=False)
print('Traced model', traced_model)
output_path = 'traced_model.onnx' # Specify the file path
# Export the model to ONNX
torch.onnx.export(traced_model, (dummy_input_x1, dummy_input_x2, dummy_input_x3, y_lengths), output_path,
verbose=True, export_params=True, opset_version=10,
input_names=['x1', 'x2', 'x3', 'y_lengths'], output_names=['output1'], dynamic_axes={'x1': {0: 'batch_size', 1: 'seq_len'},
'x2': {0: 'batch_size', 1: 'seq_len'},
'x3': {0: 'batch_size', 1: 'seq_len'},
'y_lengths': {0: 'batch_size'}})
print("PyTorch model has been successfully exported to ONNX")
if __name__ == '__main__':
# Call the function with the checkpoint path
checkpoint_path = '/home/elias/male_checkpoint.pth'
converting_onnx(checkpoint_path)
The error message I’m getting is:
Traceback (most recent call last):
File "/home/elias/miniconda3/envs/dev-elias/lib/python3.10/site-packages/torch/nn/functional.py", line 2243, in embedding
return F.embedding(input, weight, padding_idx, scale_grad_by_freq, sparse)
File "/home/elias/miniconda3/envs/dev-elias/lib/python3.10/site-packages/torch/nn/functional.py", line 2243, in embedding
return F.embedding(input, weight, padding_idx, scale_grad_by_freq, sparse)
File "/home/elias/miniconda3/envs/dev-elias/lib/python3.10/site-packages/torch/nn/functional.py", line 2243, in embedding
return F.embedding(input, weight, padding_idx, scale_grad_by_freq, sparse)
[Previous line repeated 995 more times]
File "/home/elias/miniconda3/envs/dev-elias/lib/python3.10/site-packages/torch/nn/functional.py", line 2231, in embedding
assert padding_idx < num_embeddings, "Padding_idx must be within num_embeddings"
File "/home/elias/miniconda3/envs/dev-elias/lib/python3.10/site-packages/torch/jit/_trace.py", line 55, in _get_interpreter_name_for_var
elif isinstance(v, torch.jit.ScriptModule) or isinstance(v, torch.jit.ScriptFunction):
RecursionError: maximum recursion depth exceeded while calling a Python object ```
I would appreciate any help or guidance on how to fix this issue. I've already tried increasing the recursion limit, but that didn't seem to solve the problem. Also, I've verified that the PyTorch model is loading correctly and can be traced successfully.
Here are some additional details on my environment:
I'm using Python 3.10.9
PyTorch version: 2.0.0
ONNX version: 1.13.1
Please let me know if you need any further information to help diagnose the issue. Thank you!