I’m trying to fine-tune a BART model with a custom loss. I need to generate text with BART, pass that text as part of the input to a second model, and then compute a loss and backpropagate through the whole system. The second model’s weights are frozen and its output is between 0 and 1.
In pseudocode, I want something like this:
```python
def compute_loss(self, model, inputs):
    first_output, first_loss = first_model.generate(inputs)
    text = decode(first_output)
    second_output, second_loss = second_model(text)
    loss = loss_function(second_output, targets)
    return loss
```
How can I do something like this, given that the decoding step is not differentiable?
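To make the obstacle concrete, here is a minimal sketch of where the gradient breaks, using a random logits tensor as a stand-in for the BART decoder (the shapes and names are illustrative, not BART’s actual API). Hard decoding, like `argmax` inside `generate()`, yields integer token ids, which carry no gradient back to the logits:

```python
import torch

# Toy stand-in for decoder output: (batch, seq_len, vocab_size) logits.
logits = torch.randn(1, 5, 10, requires_grad=True)

# Hard decoding (what generate() effectively does): pick the argmax token ids.
token_ids = logits.argmax(dim=-1)

# token_ids is an integer tensor, so autograd cannot flow through it.
print(token_ids.requires_grad)  # False

# By contrast, soft probabilities would keep the graph intact --
# but the second model expects text, not probability vectors.
probs = torch.softmax(logits, dim=-1)
print(probs.requires_grad)  # True
```

So the gradient chain is cut exactly at the `decode` step in the pseudocode above.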