Differentiable Softmax and Argmax Problem

I want to train two models jointly, where the output of the first model is the input to the second, so that the loss computed on the second model's output backpropagates into the first model and trains it. My approach is to use Gumbel softmax to get an approximately one-hot vector. The problem is that the second (pre-trained) model expects input_ids of type long, and converting the one-hot vectors into integer ids cuts the gradient. Is there a way to connect the two models so that gradients flow through?
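
To make the obstacle concrete, here is a minimal sketch with toy tensors (not the actual models; `vocab_size`, `weights` and the dummy loss are made up for illustration). It contrasts `argmax`, which yields long ids with no gradient, against PyTorch's `F.gumbel_softmax(..., hard=True)`, which returns a one-hot tensor in the forward pass while backpropagating through the soft sample (the straight-through trick).

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

vocab_size = 5  # hypothetical toy vocabulary size
# Stand-in for model1's logits: 2 positions over the toy vocabulary.
logits = torch.randn(2, vocab_size, requires_grad=True)

# argmax gives integer (long) ids; they are detached from the graph.
ids = logits.argmax(dim=-1)
print(ids.dtype, ids.requires_grad)  # torch.int64 False

# hard=True: forward values are one-hot, backward uses the soft Gumbel-softmax sample.
one_hot = F.gumbel_softmax(logits, tau=1.0, hard=True)
print(one_hot.requires_grad)  # True

# A dummy downstream loss still sends gradients back to the logits.
weights = torch.arange(vocab_size, dtype=torch.float32)
loss = (one_hot * weights).sum()
loss.backward()
print(logits.grad is not None)  # True
```

So the one-hot tensor itself is differentiable; the gradient is only lost at the point where it gets converted back into long `input_ids`.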

import torch.nn.functional as F
output = model1(**inputs)
logits = output[1]
one_hot = F.gumbel_softmax(logits, tau=1, hard=True)
# How can these one-hot vectors be turned into differentiable inputs for the second model?
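
One commonly used way to keep the graph intact is to skip `input_ids` entirely: multiply the one-hot tensor by the second model's embedding matrix and pass the result as embeddings. In the forward pass this equals an ordinary embedding lookup of the argmax ids, but it is a matmul, so gradients reach model1. The sketch below is self-contained and uses hypothetical toy modules (`ToyGenerator`, `ToyConsumer`, and the sizes are assumptions). It assumes both models share one vocabulary, i.e. model1's logits index model2's tokenizer. For a Hugging Face Transformers model, `get_input_embeddings().weight` and the `inputs_embeds` forward argument (accepted by many, but not necessarily all, models) would play the roles of `model2.embed.weight` and `inputs_embeds` here.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
VOCAB, HIDDEN, SEQ, BATCH = 10, 8, 4, 2  # hypothetical toy sizes

class ToyGenerator(nn.Module):
    """Hypothetical stand-in for model1: produces per-position vocabulary logits."""
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(HIDDEN, VOCAB)

    def forward(self, x):
        return self.proj(x)  # (BATCH, SEQ, VOCAB)

class ToyConsumer(nn.Module):
    """Hypothetical stand-in for the pre-trained model2: takes ids or embeddings."""
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB, HIDDEN)
        self.head = nn.Linear(HIDDEN, 1)

    def forward(self, input_ids=None, inputs_embeds=None):
        if inputs_embeds is None:
            inputs_embeds = self.embed(input_ids)  # long ids: no gradient back to model1
        return self.head(inputs_embeds.mean(dim=1)).squeeze(-1)  # (BATCH,)

model1, model2 = ToyGenerator(), ToyConsumer()
model2.requires_grad_(False)  # optional: freeze model2 and train only model1

features = torch.randn(BATCH, SEQ, HIDDEN)
logits = model1(features)
one_hot = F.gumbel_softmax(logits, tau=1.0, hard=True)  # (BATCH, SEQ, VOCAB)

# Key step: a one-hot row times the embedding matrix selects that token's embedding,
# but as a matmul it stays differentiable with respect to one_hot (and thus model1).
inputs_embeds = one_hot @ model2.embed.weight  # (BATCH, SEQ, HIDDEN)
scores = model2(inputs_embeds=inputs_embeds)

loss = scores.pow(2).mean()  # placeholder objective
loss.backward()
print(model1.proj.weight.grad.abs().sum() > 0)  # model1 receives a gradient
```

Dropping the `requires_grad_(False)` line would let the same backward pass update both models.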