Passing output of BART to another model


I’m trying to fine-tune a BART model using a custom loss. What I need is to generate text with BART, pass this text as part of the input to another model, and then compute the loss and backpropagate through the whole system. The second model’s weights are frozen, and its output is between 0 and 1.
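If the second model can share BART's vocabulary, one possible way to keep the whole system differentiable is to skip hard text entirely. Instead, sample a relaxed one-hot vector with the Gumbel-softmax trick and feed the second model a probability-weighted mix of its token embeddings. The toy below is a stdlib-only sketch for a single token. The three-token vocabulary, the embeddings, and `frozen_scorer` are hypothetical stand-ins, and binary cross-entropy is assumed as the loss. In PyTorch the corresponding pieces would roughly be `torch.nn.functional.gumbel_softmax` plus passing the mixture through the second model's `inputs_embeds` argument, assuming that model accepts one (many Hugging Face models do).

```python
import math
import random


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]


def gumbel_softmax(logits, tau, noise):
    # Relaxed one-hot sample: softmax((logits + Gumbel noise) / tau)
    return softmax([(l + g) / tau for l, g in zip(logits, noise)])


def soft_embedding(weights, embedding_matrix):
    # Probability-weighted mix of token embeddings instead of a hard lookup
    dim = len(embedding_matrix[0])
    return [sum(w * row[d] for w, row in zip(weights, embedding_matrix)) for d in range(dim)]


# Hypothetical 3-token vocabulary shared by both models, with 2-d embeddings
EMBEDDINGS = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]


def frozen_scorer(vec):
    # Hypothetical frozen second model: sigmoid of a fixed linear readout, in (0, 1)
    return 1.0 / (1.0 + math.exp(-(2.0 * vec[0] - 2.0 * vec[1])))


def relaxed_loss(logits, noise, tau=0.5, target=1.0):
    # Binary cross-entropy between the scorer output and the target (assumed loss)
    score = frozen_scorer(soft_embedding(gumbel_softmax(logits, tau, noise), EMBEDDINGS))
    return -(target * math.log(score) + (1 - target) * math.log(1 - score))


rng = random.Random(0)
# Standard Gumbel noise -log(-log(U)), U ~ Uniform(0, 1); held fixed so both
# finite-difference evaluations see the same sample
noise = [-math.log(-math.log(rng.random())) for _ in range(3)]
logits = [2.0, 1.0, 0.5]
eps = 1e-6
grad0 = (relaxed_loss([logits[0] + eps] + logits[1:], noise) - relaxed_loss(logits, noise)) / eps
print(grad0)  # negative: raising the logit of the token the scorer prefers lowers the loss
```

Because the loss now depends smoothly on the logits, the derivative is nonzero. Lowering `tau` makes the samples closer to one-hot, at the cost of higher-variance gradients. The second model also sees embedding mixtures rather than real text, which it was not trained on.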

In pseudocode, I want something like this:

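def compute_loss(self, model, inputs):
    targets = inputs.pop("labels")                # targets for the second model's score
    first_output = model.generate(**inputs)       # BART generates token ids (no loss is returned)
    text = decode(first_output)                   # token ids -> strings (not differentiable)
    second_output = second_model(text)            # frozen model, output between 0 and 1
    loss = loss_function(second_output, targets)
    return loss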
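To see why the pipeline above gives BART no gradient, here is a stdlib-only toy version. It uses a hypothetical one-step "generator" that picks the highest-scoring token (greedy decoding), a hypothetical frozen scorer with outputs in [0, 1], and binary cross-entropy as `loss_function` (an assumption; the post doesn't specify it). Nudging any logit by a small epsilon doesn't change the chosen token, so the finite-difference derivative of the loss is exactly zero: argmax decoding is piecewise constant in the logits.

```python
import math

VOCAB = ["good", "bad", "ok"]


def greedy_decode(logits):
    # Hypothetical one-token "generator": pick the highest-scoring token (argmax)
    best = max(range(len(logits)), key=lambda i: logits[i])
    return VOCAB[best]


def frozen_scorer(text):
    # Hypothetical frozen second model: returns a score in [0, 1]
    return {"good": 0.9, "bad": 0.1, "ok": 0.5}[text]


def loss_fn(score, target):
    # Binary cross-entropy between the score and a 0/1 target
    return -(target * math.log(score) + (1 - target) * math.log(1 - score))


def pipeline_loss(logits, target=1.0):
    return loss_fn(frozen_scorer(greedy_decode(logits)), target)


def finite_diff(logits, i, eps=1e-4):
    # Forward-difference estimate of d(loss)/d(logits[i])
    bumped = list(logits)
    bumped[i] += eps
    return (pipeline_loss(bumped) - pipeline_loss(logits)) / eps


logits = [2.0, 1.0, 0.5]
grads = [finite_diff(logits, i) for i in range(len(logits))]
print(grads)  # [0.0, 0.0, 0.0]
```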

How can I do something like this, given that generation and decoding are not differentiable?
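A common workaround (not something this post proposes) is to treat BART as a policy and the frozen model's score as a reward, for example the negative of `loss_function`. The REINFORCE (score-function) estimator then gives the gradient of the expected reward as E[(r - b) ∇ log p(y)]. It needs only the log-probability of the sampled text, not a derivative through decoding. The sketch below is a stdlib-only toy with a hypothetical three-token vocabulary and a hypothetical score table standing in for the second model. For BART, log p(y) would be the sum of the token log-probabilities of the generated sequence under the model. Libraries such as Hugging Face TRL implement more elaborate policy-gradient methods (e.g. PPO) along these lines.

```python
import math
import random

VOCAB = ["good", "bad", "ok"]
SCORES = {"good": 0.9, "bad": 0.1, "ok": 0.5}  # hypothetical frozen scorer outputs in [0, 1]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce_step(logits, rng, lr=0.5, n_samples=16):
    probs = softmax(logits)
    # Sample "generations" from the current policy (non-differentiable step)
    samples = rng.choices(range(len(VOCAB)), weights=probs, k=n_samples)
    rewards = [SCORES[VOCAB[a]] for a in samples]
    baseline = sum(rewards) / len(rewards)  # mean reward as a variance-reducing baseline
    grad = [0.0] * len(logits)
    for a, r in zip(samples, rewards):
        advantage = r - baseline
        for j in range(len(logits)):
            # d/d logit_j of log softmax(logits)[a] = 1[j == a] - p_j
            grad[j] += advantage * ((1.0 if j == a else 0.0) - probs[j])
    # Gradient ascent on the expected reward
    return [l + lr * g / n_samples for l, g in zip(logits, grad)]


rng = random.Random(0)
logits = [0.0, 0.0, 0.0]
for _ in range(200):
    logits = reinforce_step(logits, rng)
print([round(p, 3) for p in softmax(logits)])
```

After training, most of the probability mass sits on the token the frozen scorer rates highest. The baseline does not bias the gradient estimate; it only reduces its variance.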

Thank you