I want to train two models where the output of the first is the input of the second, and I want the gradients from the second model's loss to backpropagate through it into the first model. My approach is to use Gumbel softmax to get an approximately one-hot vector. The problem is that the second (pre-trained) model takes `input_ids` as a long-typed tensor of token indices, which is not differentiable. Is there a way to connect the two models?
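```python
import torch.nn.functional as F

output = model1(**inputs)
logits = output.logits  # assuming model1 returns a Hugging Face-style output object
one_hot = F.gumbel_softmax(logits, tau=1, hard=True)
# then how do I turn these one-hot vectors into differentiable inputs for the second model?
```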
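A common way to bridge the gap is to skip the integer lookup entirely: multiply the (straight-through) one-hot tensor by the second model's embedding matrix and feed the result as embeddings instead of `input_ids`. For a hard one-hot this reproduces the ordinary embedding lookup in the forward pass, while the backward pass still reaches `model1`. The sketch below is a minimal, self-contained illustration under stated assumptions: `model1` and `ToyModel2` are hypothetical toy stand-ins, the sizes are arbitrary, and it assumes `model1`'s logits are over the same vocabulary (same tokenizer) as `model2`'s embedding table. With a Hugging Face `model2`, the analogous pieces would be `model2.get_input_embeddings().weight` and the `inputs_embeds=` argument that most `transformers` models accept in place of `input_ids`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy sizes (hypothetical, for illustration only)
VOCAB, HIDDEN, SEQ, BATCH = 10, 8, 5, 2

# Hypothetical stand-in for model1: produces per-token logits over the vocabulary
model1 = nn.Linear(HIDDEN, VOCAB)


class ToyModel2(nn.Module):
    """Hypothetical stand-in for the pre-trained model2: embedding + small head."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB, HIDDEN)
        self.head = nn.Linear(HIDDEN, 1)

    def forward(self, input_ids=None, inputs_embeds=None):
        # Accept either integer ids (normal path) or precomputed embeddings
        # (differentiable path), mirroring the input_ids / inputs_embeds split
        if inputs_embeds is None:
            inputs_embeds = self.embed(input_ids)
        return self.head(inputs_embeds.mean(dim=1))


model2 = ToyModel2()

x = torch.randn(BATCH, SEQ, HIDDEN)
logits = model1(x)  # shape (BATCH, SEQ, VOCAB)

# hard=True: the forward value is a one-hot vector, but gradients flow
# through the underlying soft Gumbel-softmax sample (straight-through)
one_hot = F.gumbel_softmax(logits, tau=1.0, hard=True)

# Replace the integer lookup with a matrix product:
# (BATCH, SEQ, VOCAB) @ (VOCAB, HIDDEN) -> (BATCH, SEQ, HIDDEN)
inputs_embeds = one_hot @ model2.embed.weight

out = model2(inputs_embeds=inputs_embeds)
loss = out.sum()
loss.backward()  # gradients reach model1's parameters through one_hot

# Sanity check: the differentiable path matches an ordinary id lookup
ids = one_hot.argmax(dim=-1)
matches_lookup = torch.allclose(model2(input_ids=ids), out)
print(matches_lookup)
print(model1.weight.grad is not None)
```

If `model2` should stay frozen, set `requires_grad_(False)` on its parameters; gradients still pass through its activations back to `model1`, since freezing only stops `model2`'s own weights from being updated.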