Finetuning BART using custom loss

Hi @himanshu, the simplest way to implement a custom loss function is to subclass the Trainer class and override its compute_loss method, e.g.

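```python
from transformers import Trainer

class BartTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Forward pass; `inputs` is the batch dict produced by the data collator
        outputs = model(**inputs)
        # Implement custom logic here, e.g. a loss computed from outputs.logits
        custom_loss = outputs.loss  # placeholder: the model's built-in loss
        # Newer Trainer versions call this with return_outputs=True during evaluation
        return (custom_loss, outputs) if return_outputs else custom_loss
```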
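As one possible filling for the placeholder, here is a minimal sketch of a label-smoothed cross-entropy computed from BART's decoder logits. It assumes PyTorch 1.10 or later (for the `label_smoothing` argument of `F.cross_entropy`), that padding positions in `labels` are set to -100, and that each batch contains a `labels` key. `smoothed_seq2seq_loss` and `smoothing_epsilon` are hypothetical names chosen for illustration; they are not part of transformers.

```python
import torch
import torch.nn.functional as F
from transformers import Trainer


def smoothed_seq2seq_loss(logits, labels, epsilon=0.1, ignore_index=-100):
    """Label-smoothed token-level cross-entropy (hypothetical helper).

    logits: (batch, tgt_len, vocab_size); labels: (batch, tgt_len), with
    padding positions set to ignore_index (-100 by Hugging Face convention).
    """
    # cross_entropy expects (N, C) logits and (N,) targets, so flatten the
    # batch and sequence dimensions together; ignored tokens are excluded
    # from the mean
    return F.cross_entropy(
        logits.reshape(-1, logits.size(-1)),
        labels.reshape(-1),
        ignore_index=ignore_index,
        label_smoothing=epsilon,
    )


class BartTrainer(Trainer):
    # Hypothetical smoothing strength; this is not a Trainer argument
    smoothing_epsilon = 0.1

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Keep "labels" in inputs: BART shifts them to build decoder_input_ids
        # when the collator does not supply them (the model also computes its
        # own loss, which is ignored here)
        outputs = model(**inputs)
        loss = smoothed_seq2seq_loss(
            outputs.logits, inputs["labels"], epsilon=self.smoothing_epsilon
        )
        return (loss, outputs) if return_outputs else loss
```

BartTrainer is then constructed with the usual Trainer arguments (model, args, datasets, data collator). Recomputing the loss from the logits wastes the model's built-in loss computation; the trade-off is that the batch can be passed to the model unchanged.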

You can find more details in the docs here: Trainer — transformers 4.3.0 documentation
