I want to fine-tune BART using a custom loss. What I want to do is take the output text generated by the BART model, feed it to a classifier, and update the weights of the BART model using the classification loss. Please note that I do not want to train the classifier; rather, I want to train the BART model using the classification loss on the generated text. Can someone give me pointers on how to do it?
Hi, this may be a bit naive, but do you then pass the parameters to BartTrainer as you would to Trainer, like this:
trainer = BartTrainer(
    model=model,                      # the instantiated Transformers model to be trained
    args=training_args,               # training arguments, defined above
    train_dataset=train_dataset,      # training dataset
    eval_dataset=val_dataset,         # evaluation dataset
    compute_metrics=compute_metrics,
)
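For reference, here is a minimal sketch of what such a Trainer subclass might look like, assuming the standard compute_loss(self, model, inputs, return_outputs=False) hook of the Transformers Trainer; the custom classifier-based logic would replace the default loss computation:

from transformers import Trainer

class BartTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        # Default behaviour: a forward pass and the model's own loss.
        # Custom logic (e.g. a classifier-based loss) would go here.
        outputs = model(**inputs)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss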
Hi @lewtun, thanks for the reply. Actually, I am looking for how to compute the loss, i.e. how to “implement custom logic”. What I mean is that the output of BART is some text; I want to feed that text to the classifier and update the weights of BART based on the classification loss. How do I code such a system? I hope that makes sense.
TIA
Hi @himanshu, could you provide a code snippet of what you have in mind (even if it’s pseudo-code)? For example, can you show what the inputs/outputs of the classifier should look like and how the classification loss is computed?
def compute_loss(self, model, inputs, classifier):
    # Generate token ids, decode them to text, then re-encode the
    # text for the classifier (note: generate() is non-differentiable).
    output_ids = model.generate(inputs)
    text = tok.batch_decode(output_ids, skip_special_tokens=True)
    # convert text back to ids for the classifier
    classifier_inputs = tok(text, return_tensors="pt", padding=True)
    classifier_output = classifier(**classifier_inputs)
    loss = loss_function(classifier_output, targets)
    return loss
Now, when I do loss.backward(), I want to update the weights of model and not classifier. I understand that non-differentiability is involved as soon as I use the model to generate text; I want to know how to handle that.
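One common workaround (a sketch under assumptions, not a definitive recipe) is to avoid the discrete generate-and-decode step entirely: take the decoder logits from a normal forward pass, relax them with a straight-through Gumbel-softmax, and feed the resulting soft one-hot tokens into the classifier through its embedding matrix, so gradients can flow back into BART. This assumes the classifier is a standard Transformers sequence-classification model that shares BART's tokenizer/vocabulary; targets and the batch layout are placeholders:

import torch.nn.functional as F

def compute_loss(model, classifier, inputs, targets, tau=1.0):
    # Forward pass through BART: logits is (batch, seq_len, vocab_size).
    logits = model(**inputs).logits

    # Straight-through Gumbel-softmax: hard one-hot tokens in the
    # forward pass, gradients through the soft relaxation.
    one_hot = F.gumbel_softmax(logits, tau=tau, hard=True)

    # Project the one-hot tokens into the classifier's embedding space
    # instead of decoding to text (padding/attention-mask handling is
    # elided for brevity).
    emb = classifier.get_input_embeddings().weight  # (vocab, dim)
    inputs_embeds = one_hot @ emb                   # (batch, seq, dim)

    class_logits = classifier(inputs_embeds=inputs_embeds).logits
    return F.cross_entropy(class_logits, targets)

Freezing the classifier first (for p in classifier.parameters(): p.requires_grad_(False)) ensures loss.backward() updates only BART; REINFORCE-style policy gradients over sampled tokens are the other standard option.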
I did it using PyTorch Lightning and changed this part:
outputs = self(batch[0], batch[1], batch[2])
loss = outputs["loss"]
# Flatten logits and labels so cross_entropy compares (N, vocab) with (N,)
y_hat = outputs["logits"].view(-1, outputs["logits"].shape[-1])
y = batch[2].view(-1)
cross_entropy_loss = F.cross_entropy(y_hat, y, reduction="mean")  # you can change this to your own custom loss
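For completeness, here is a minimal sketch of how that snippet might sit inside a LightningModule; the class name BartFineTuner, the (input_ids, attention_mask, labels) batch layout, and the optimizer settings are assumptions, not part of the thread:

import torch
import torch.nn.functional as F
import pytorch_lightning as pl

class BartFineTuner(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, attention_mask, labels):
        return self.model(input_ids=input_ids,
                          attention_mask=attention_mask,
                          labels=labels)

    def training_step(self, batch, batch_idx):
        outputs = self(batch[0], batch[1], batch[2])
        # Flatten logits and labels for token-level cross-entropy;
        # swap this for your own custom loss.
        y_hat = outputs["logits"].view(-1, outputs["logits"].shape[-1])
        y = batch[2].view(-1)
        return F.cross_entropy(y_hat, y, reduction="mean")

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=3e-5)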