Converting AlignTTS (text-to-speech) model to ONNX

I'm trying to convert a PyTorch AlignTTS model to ONNX. I trace it with torch.jit.trace and then export it with torch.onnx.export, but the conversion fails with a RecursionError. This is the code I'm using:

```python
import sys

import torch

# Assumption: AlignTTS and AlignTTSArgs are imported from the Coqui TTS package
from TTS.tts.models.align_tts import AlignTTS, AlignTTSArgs


def converting_onnx(checkpoint_path):
    # Create an instance of AlignTTSArgs with the desired values
    args = AlignTTSArgs(num_chars=105, out_channels=80, hidden_channels=256, hidden_channels_dp=256,
                        encoder_type='fftransformer',
                        encoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
                        decoder_type='fftransformer',
                        decoder_params={"hidden_channels_ffn": 1024, "num_heads": 2, "num_layers": 6, "dropout_p": 0.1},
                        length_scale=1.0, num_speakers=0, use_speaker_embedding=False, use_d_vector_file=False, d_vector_dim=0)

    # Create an instance of AlignTTS
    model = AlignTTS(config=args)

    # Load the saved PyTorch checkpoint (a .pth file, not an ONNX file)
    checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    # Load the state dict into the model
    model.load_state_dict(checkpoint['model'])

    print(model)

    # Set the model to evaluation mode
    model.eval()

    batch_size = 8
    seq_len = 256

    # Create dummy inputs
    dummy_input_x1 = torch.zeros(batch_size, seq_len, dtype=torch.long)
    dummy_input_x2 = torch.zeros(batch_size, seq_len, dtype=torch.long)
    dummy_input_x3 = torch.zeros(batch_size, seq_len, dtype=torch.long)
    y_lengths = torch.zeros(batch_size, dtype=torch.long)

    print('Dummy input X1:', dummy_input_x1)
    print('Dummy input X2:', dummy_input_x2)
    print('Dummy input X3:', dummy_input_x3)
    print('Y lengths:', y_lengths)

    # Increase the recursion limit
    sys.setrecursionlimit(5000)

    # Trace the model
    traced_model = torch.jit.trace(model, (dummy_input_x1, dummy_input_x2, dummy_input_x3, y_lengths), strict=False)

    print('Traced model', traced_model)

    output_path = 'traced_model.onnx'  # Path of the exported ONNX file

    # Export the traced model to ONNX
    torch.onnx.export(
        traced_model,
        (dummy_input_x1, dummy_input_x2, dummy_input_x3, y_lengths),
        output_path,
        verbose=True,
        export_params=True,
        opset_version=10,
        input_names=['x1', 'x2', 'x3', 'y_lengths'],
        output_names=['output1'],
        dynamic_axes={
            'x1': {0: 'batch_size', 1: 'seq_len'},
            'x2': {0: 'batch_size', 1: 'seq_len'},
            'x3': {0: 'batch_size', 1: 'seq_len'},
            'y_lengths': {0: 'batch_size'},
        },
    )
    print("PyTorch model has been successfully exported to ONNX")


if __name__ == '__main__':
    # Call the function with the checkpoint path
    checkpoint_path = '/home/elias/male_checkpoint.pth'
    converting_onnx(checkpoint_path)
```
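The dynamic_axes argument of torch.onnx.export maps each input name to a dict of {axis index: symbolic dimension name}. As a small sketch (a refactoring of the literal dict in the export call, not a fix for the error), the same mapping can be derived from the input names so the two stay in sync; token_inputs is a hypothetical helper list introduced only for this example.

```python
# Input names used in the export call.
input_names = ["x1", "x2", "x3", "y_lengths"]

# Hypothetical helper list: the three 2-D token inputs, shaped (batch, sequence).
token_inputs = ["x1", "x2", "x3"]

# Batch (axis 0) and sequence length (axis 1) are dynamic for the 2-D inputs.
dynamic_axes = {name: {0: "batch_size", 1: "seq_len"} for name in token_inputs}

# y_lengths is 1-D, shaped (batch,), so only axis 0 is dynamic.
dynamic_axes["y_lengths"] = {0: "batch_size"}

print(dynamic_axes)
```

Building the dict this way makes it harder to mark an axis as dynamic on an input that does not have it, such as axis 1 on the 1-D y_lengths tensor.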

The error message I’m getting is:

```
Traceback (most recent call last):
  File "/home/elias/miniconda3/envs/dev-elias/lib/python3.10/site-packages/torch/nn/functional.py", line 2243, in embedding
    return F.embedding(input, weight, padding_idx, scale_grad_by_freq, sparse)
  File "/home/elias/miniconda3/envs/dev-elias/lib/python3.10/site-packages/torch/nn/functional.py", line 2243, in embedding
    return F.embedding(input, weight, padding_idx, scale_grad_by_freq, sparse)
  File "/home/elias/miniconda3/envs/dev-elias/lib/python3.10/site-packages/torch/nn/functional.py", line 2243, in embedding
    return F.embedding(input, weight, padding_idx, scale_grad_by_freq, sparse)
  [Previous line repeated 995 more times]
  File "/home/elias/miniconda3/envs/dev-elias/lib/python3.10/site-packages/torch/nn/functional.py", line 2231, in embedding
    assert padding_idx < num_embeddings, "Padding_idx must be within num_embeddings"
  File "/home/elias/miniconda3/envs/dev-elias/lib/python3.10/site-packages/torch/jit/_trace.py", line 55, in _get_interpreter_name_for_var
    elif isinstance(v, torch.jit.ScriptModule) or isinstance(v, torch.jit.ScriptFunction):
RecursionError: maximum recursion depth exceeded while calling a Python object
```
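The traceback shows the same embedding frame repeated hundreds of times, which looks like a call chain that never reaches a base case. If that reading is correct (an assumption; the traceback alone does not show why embedding re-enters itself), raising the recursion limit can only delay the error, never prevent it. The toy example below illustrates this; self_calling and hits_limit are hypothetical names standing in for the repeated frame, not part of PyTorch.

```python
import sys


def self_calling(x):
    # Hypothetical stand-in for the repeated frame: there is no base case,
    # so every call immediately makes another identical call.
    return self_calling(x)


def hits_limit(limit):
    """Return True if self_calling raises RecursionError under `limit`."""
    old = sys.getrecursionlimit()
    sys.setrecursionlimit(limit)
    try:
        self_calling(0)
    except RecursionError:
        return True
    finally:
        # Always restore the interpreter's original limit.
        sys.setrecursionlimit(old)
    return False


for limit in (1000, 5000):
    print(limit, hits_limit(limit))
```

Both the default-sized limit and the 5000 used in the code above end in RecursionError, because unbounded recursion exceeds any finite limit.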

I would appreciate any help or guidance on how to fix this. I've already tried increasing the recursion limit (sys.setrecursionlimit(5000) in the code above), but that did not solve the problem. I've also verified that the checkpoint loads into the model correctly and that the model can be traced successfully.
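One way to narrow the problem down is to trace a much smaller module that only exercises an embedding with the same dummy shapes. If this traces cleanly in the same environment, the recursion is more likely specific to the full model or to the installed torch files than to nn.Embedding in general. TinyEncoder is a hypothetical stand-in, not the AlignTTS encoder, and the sketch deliberately uses lengths equal to seq_len rather than zero so that the mask keeps every position.

```python
import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in: token ids -> embeddings, masked by sequence lengths."""

    def __init__(self, num_chars=105, hidden=256):
        super().__init__()
        self.emb = nn.Embedding(num_chars, hidden)

    def forward(self, x, x_lengths):
        # mask[b, t] is True for real (non-padded) positions of sequence b
        positions = torch.arange(x.size(1), device=x.device)
        mask = positions.unsqueeze(0) < x_lengths.unsqueeze(1)
        # Zero out the embeddings of padded positions
        return self.emb(x) * mask.unsqueeze(-1)


model = TinyEncoder().eval()
batch_size, seq_len = 8, 256
x = torch.zeros(batch_size, seq_len, dtype=torch.long)
x_lengths = torch.full((batch_size,), seq_len, dtype=torch.long)

with torch.no_grad():
    traced = torch.jit.trace(model, (x, x_lengths))
    out = traced(x, x_lengths)

print(tuple(out.shape))
```

Because the mask is computed from x_lengths inside forward, the traced module still respects shorter lengths passed at call time, as long as the input shape matches the traced one.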

Here are some additional details on my environment:

Python version: 3.10.9
PyTorch version: 2.0.0
ONNX version: 1.13.1
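The versions above can be collected with the standard library alone, without importing torch itself. This is a small sketch; collect_versions is a hypothetical helper, and "torch" and "onnx" are the installed distribution names it queries.

```python
import platform
from importlib.metadata import PackageNotFoundError, version


def collect_versions(packages=("torch", "onnx")):
    """Return the Python version plus the installed version of each package."""
    info = {"python": platform.python_version()}
    for name in packages:
        try:
            info[name] = version(name)
        except PackageNotFoundError:
            # The distribution is not installed in this environment
            info[name] = "not installed"
    return info


for key, value in collect_versions().items():
    print(f"{key}: {value}")
```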

Please let me know if you need any further information to help diagnose the issue. Thank you!