Fine-tuning BART using a custom loss

Hi everyone,

I want to fine-tune BART using a custom loss. What I want to do is take the output text generated by the BART model, feed it to a classifier, and update the weights of the BART model using the classification loss. Please note that I do not want to train the classifier; rather, I want to train the BART model using the classification loss on the generated text. Can someone give me pointers on how to do it?

TIA


Hi @himanshu, the simplest way to implement custom loss functions is by subclassing the Trainer class and overriding the compute_loss function, e.g.

from transformers import Trainer

class BartTrainer(Trainer):
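    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        outputs = model(**inputs)
        # implement custom logic here, e.g. starting from outputs.logits
        custom_loss = ...
        # newer Trainer versions call compute_loss with return_outputs=True during evaluation
        return (custom_loss, outputs) if return_outputs else custom_loss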

You can find more details in the docs here: Trainer — transformers 4.3.0 documentation
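As an illustration of what the "custom logic" could be, here is a minimal sketch, assuming a seq2seq model whose `outputs.logits` has shape (batch, seq_len, vocab) and whose `labels` use -100 for padding (the convention BART uses). The function name `smoothed_seq2seq_loss` is hypothetical, and label smoothing is just one example of a custom loss; the commented lines show how it might be called from `compute_loss`.

```python
import torch
import torch.nn.functional as F


def smoothed_seq2seq_loss(logits: torch.Tensor, labels: torch.Tensor,
                          smoothing: float = 0.1) -> torch.Tensor:
    """Label-smoothed, padding-aware token-level cross-entropy (hypothetical helper).

    logits: (batch, seq_len, vocab_size) decoder logits
    labels: (batch, seq_len) target token ids, with -100 marking padding
    """
    # Flatten to (batch * seq_len, vocab_size) and (batch * seq_len,)
    flat_logits = logits.reshape(-1, logits.size(-1))
    flat_labels = labels.reshape(-1)
    # ignore_index=-100 drops padded positions from the average
    return F.cross_entropy(flat_logits, flat_labels,
                           ignore_index=-100, label_smoothing=smoothing)


# Possible use inside a Trainer subclass (sketch, not run here):
#
# class BartTrainer(Trainer):
#     def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
#         outputs = model(**inputs)
#         loss = smoothed_seq2seq_loss(outputs.logits, inputs["labels"])
#         return (loss, outputs) if return_outputs else loss
```

`label_smoothing` requires PyTorch 1.10 or newer. Note that this is still a loss on the reference labels; it does not by itself involve generated text.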


Hi, this may be a naive question, but do you then pass the parameters to BartTrainer just as you would to Trainer, like this?

trainer = BartTrainer(
    model=model,                         # the instantiated Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=val_dataset,            # evaluation dataset
    compute_metrics=compute_metrics,     # function that computes evaluation metrics
)

yep, that’s it: by subclassing the Trainer class, BartTrainer inherits all of its attributes and methods.

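strictly speaking, you only need to call super().__init__() if you override __init__ in the subclass (for example, to accept extra arguments); otherwise Trainer's constructor is used automatically. The explicit version looks like this: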

class BartTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

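    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        outputs = model(**inputs)
        # implement custom logic here, e.g. starting from outputs.logits
        custom_loss = ...
        # newer Trainer versions call compute_loss with return_outputs=True during evaluation
        return (custom_loss, outputs) if return_outputs else custom_loss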

there’s a nice template of a custom trainer for question-answering here that you could work from: transformers/trainer_qa.py at master · huggingface/transformers · GitHub
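The difference between inheriting and overriding `__init__` can be seen in plain Python. In this sketch `BaseTrainer` is a hypothetical stand-in for `transformers.Trainer`, and the `classifier` keyword is an illustrative extra argument (the kind of thing a classifier-loss setup would need to pass in):

```python
class BaseTrainer:
    """Hypothetical stand-in for transformers.Trainer, used only to show the pattern."""

    def __init__(self, model=None, args=None):
        self.model = model
        self.args = args


class InheritsInit(BaseTrainer):
    # No __init__ here: Python uses BaseTrainer.__init__ automatically.
    pass


class WithClassifier(BaseTrainer):
    def __init__(self, *args, classifier=None, **kwargs):
        # Overriding __init__ replaces the parent's constructor, so it must be
        # called explicitly, otherwise self.model / self.args are never set.
        super().__init__(*args, **kwargs)
        self.classifier = classifier  # hypothetical extra attribute
```

With the real `Trainer`, the same pattern would let you write something like `BartTrainer(model=model, args=training_args, classifier=clf)` and then use `self.classifier` inside `compute_loss`.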


Hi @lewtun, thanks for the reply. What I'm actually looking for is how to compute the loss, i.e. how to “implement custom logic”. The output of BART is text; I want to feed that text to the classifier and update the weights of BART based on the classification loss. How do I code such a system? I hope that makes sense.
TIA
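One common way to train a generator on a score computed from discrete generated text is a policy-gradient (REINFORCE) objective: sample token ids, score them with the frozen classifier (no gradient needed through the text), and weight the log-probability of each sample by its reward. This is a swapped-in technique, not something proposed in this thread. The sketch below uses hypothetical toy stand-ins (`generator_logits` in place of BART, a small `EmbeddingBag` classifier) so it runs without downloading models.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
VOCAB, SEQ_LEN, N_CLASSES, BATCH = 20, 6, 2, 8

# Hypothetical toy stand-ins: per-position logits playing the role of BART's
# output distribution, and a small bag-of-embeddings classifier.
generator_logits = nn.Parameter(torch.zeros(SEQ_LEN, VOCAB))
classifier = nn.Sequential(nn.EmbeddingBag(VOCAB, 16), nn.Linear(16, N_CLASSES))

# Freeze the classifier: it only scores samples and is never updated.
classifier.requires_grad_(False)
classifier.eval()

# The optimizer only sees the generator's parameters.
optimizer = torch.optim.Adam([generator_logits], lr=0.1)
target = torch.ones(BATCH, dtype=torch.long)  # class we want the generated text to get


def reinforce_step() -> float:
    dist = torch.distributions.Categorical(
        logits=generator_logits.expand(BATCH, -1, -1))
    ids = dist.sample()                        # (BATCH, SEQ_LEN) discrete token ids
    log_prob = dist.log_prob(ids).sum(dim=1)   # differentiable w.r.t. generator_logits
    with torch.no_grad():
        # Reward: the frozen classifier's log-probability of the target class.
        reward = classifier(ids).log_softmax(-1).gather(1, target[:, None]).squeeze(1)
        advantage = reward - reward.mean()     # batch-mean baseline to reduce variance
    loss = -(advantage * log_prob).mean()      # REINFORCE surrogate loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return reward.mean().item()
```

In a real setup, the ids would presumably come from `model.generate(..., do_sample=True)`; because `generate()` runs without gradients, the log-probabilities would have to come from a second, gradient-enabled forward pass of BART that scores the sampled sequences. Since the reward needs no gradient, the text can be decoded and re-tokenized with the classifier's own tokenizer.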

Hi @himanshu, could you provide a code snippet of what you have in mind (even if it’s pseudocode)? For example, can you show what the inputs / outputs of the classifier should look like and how the classification loss is computed?

Hi @lewtun, I have something like this in mind.

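def compute_loss(self, model, inputs, classifier):
    output = model.generate(inputs["input_ids"])
    # generate() returns a batch of sequences, so decode them all
    text = tok.batch_decode(output, skip_special_tokens=True)
    # convert text to ids (e.g. with the classifier's tokenizer)
    classifier_output = classifier(text)
    loss = loss_function(classifier_output, targets)
    return loss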

Now, when I call loss.backward(), I want to update the weights of the BART model and not the classifier. I understand that the pipeline becomes non-differentiable as soon as I use the model to generate text; how should I handle that?
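One way to keep a gradient path from the classifier loss back to the generator is the straight-through Gumbel-softmax trick: instead of discrete ids, pass (near) one-hot vectors over the vocabulary and let the classifier embed them with a matrix multiply. This is a minimal sketch, assuming the classifier shares the generator's vocabulary (so a one-hot over BART's vocabulary can index the classifier's embedding matrix, e.g. via `inputs_embeds` on a Hugging Face model). `ToyClassifier` and `generator_logits` are hypothetical toy stand-ins; with a real BART you would need decoder logits from a decoding loop run with gradients enabled, since `generate()` does not track gradients.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
VOCAB, SEQ_LEN, N_CLASSES, BATCH, DIM = 20, 6, 2, 4, 16


class ToyClassifier(nn.Module):
    """Hypothetical classifier that consumes (soft) one-hot token vectors."""

    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(VOCAB, DIM)
        self.head = nn.Linear(DIM, N_CLASSES)

    def forward(self, one_hot: torch.Tensor) -> torch.Tensor:
        # (B, T, V) @ (V, D) -> (B, T, D): a differentiable embedding lookup
        embeds = one_hot @ self.embed.weight
        return self.head(embeds.mean(dim=1))


# Stand-in for BART's decoder logits (hypothetical; a real model would produce these).
generator_logits = nn.Parameter(torch.randn(BATCH, SEQ_LEN, VOCAB))
classifier = ToyClassifier()
classifier.requires_grad_(False)  # frozen: only the generator receives gradients

# Ordinary argmax decoding gives integer ids with no gradient path back.
hard_ids = generator_logits.argmax(dim=-1)

# Straight-through Gumbel-softmax: exact one-hot vectors in the forward pass,
# gradients of the relaxed softmax sample in the backward pass.
one_hot = F.gumbel_softmax(generator_logits, tau=1.0, hard=True, dim=-1)
target = torch.ones(BATCH, dtype=torch.long)
loss = F.cross_entropy(classifier(one_hot), target)
loss.backward()  # gradients reach generator_logits, not the classifier
```

If the classifier uses a different tokenizer, this trick does not apply directly, and a policy-gradient objective (scoring decoded text without gradients) is the usual alternative.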

@lewtun @himanshu Does anyone have an idea about this?

@himanshu @minji Did either of you get this to work?

I did it using PyTorch Lightning and changed this part:

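import torch.nn.functional as F

# e.g. inside the LightningModule's training_step; batch[2] holds the labels
outputs = self(batch[0], batch[1], batch[2])
loss = outputs["loss"]  # the model's built-in loss (not used below)
y_hat, y = outputs["logits"].view(-1, outputs["logits"].shape[-1]), batch[2].view(-1)
cross_entropy_loss = F.cross_entropy(y_hat, y, reduction="mean")  # you can change it to your own custom loss
return cross_entropy_loss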
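To make the flattening in the Lightning snippet concrete, the sketch below (random toy tensors, assuming labels use -100 for padding as BART does) checks that `F.cross_entropy` on the flattened `y_hat, y` equals the average negative log-likelihood over the non-padded tokens:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch, seq_len, vocab = 2, 4, 7
logits = torch.randn(batch, seq_len, vocab)
labels = torch.randint(0, vocab, (batch, seq_len))
labels[1, 2:] = -100  # padded positions, skipped by the default ignore_index=-100

# Same flattening as y_hat, y: (B, T, V) -> (B*T, V) and (B, T) -> (B*T,)
y_hat, y = logits.view(-1, logits.shape[-1]), labels.view(-1)
flat_loss = F.cross_entropy(y_hat, y, reduction="mean")

# Manual equivalent: average -log p(label) over the non-padded tokens only.
mask = labels != -100
log_probs = logits.log_softmax(dim=-1)
token_nll = -log_probs.gather(-1, labels.clamp(min=0).unsqueeze(-1)).squeeze(-1)
manual_loss = token_nll[mask].mean()
```

Note that this is the usual teacher-forced loss on the reference labels, so it is the same objective as the model's built-in `outputs["loss"]`; replacing it with a classifier loss on generated text still runs into the non-differentiability raised earlier in the thread.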