Hello,
I’m trying to fine-tune a BART model using a custom loss. What I need is to generate text with BART, pass this text as part of the input to another model, and then compute the loss and backpropagate through the whole system. The second model’s weights are frozen and its output is between 0 and 1.
In pseudocode, I want something like this:
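    def compute_loss(self, model, inputs):
        # generate text with the trainable BART model
        first_output, first_loss = first_model.generate(inputs)
        # turn the generated token ids into a string (the non-differentiable step)
        text = decode(first_output)
        # score the text with the frozen second model (output between 0 and 1)
        second_output, second_loss = second_model(text)
        loss = loss_function(second_output, targets)
        return loss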
How can I do something like this, given that the decoding step is not differentiable?
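To make the problem concrete, here is a toy sketch in plain Python (not BART, and not my real code). It assumes greedy decoding, and `frozen_scorer` is a hypothetical stand-in for the frozen second model. It estimates the gradient of the score with respect to each logit by finite differences, and every estimate comes out as zero because a small change in the logits does not change the decoded tokens:

```python
# Toy illustration only: greedy decoding picks argmax token ids, so a
# small change in the logits leaves the decoded sequence unchanged.

def greedy_decode(logits):
    """Return the index of the largest logit at each position."""
    return [max(range(len(row)), key=row.__getitem__) for row in logits]

def frozen_scorer(token_ids):
    """Hypothetical stand-in for the frozen second model (score in [0, 1])."""
    return sum(token_ids) / (10 * len(token_ids))

def finite_difference(logits, pos, tok, eps=1e-3):
    """Estimate d(score)/d(logits[pos][tok]) through the decode step."""
    bumped = [row[:] for row in logits]
    bumped[pos][tok] += eps
    base = frozen_scorer(greedy_decode(logits))
    return (frozen_scorer(greedy_decode(bumped)) - base) / eps

# Made-up logits for a 2-token sequence over a 3-token vocabulary.
logits = [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3]]
token_ids = greedy_decode(logits)
grads = [[finite_difference(logits, p, t) for t in range(3)] for p in range(2)]

print(token_ids)  # [1, 0]
print(grads)      # every entry is 0.0
```

So even if the second model were differentiable with respect to its own input, no useful gradient would reach the BART logits through the decode step.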
Thank you