Finetuning BART using custom loss

yep, that’s it - by subclassing the Trainer class, BartTrainer inherits all the attributes and functions :slightly_smiling_face:

strictly speaking, if you define your own __init__ in the subclass you should call super().__init__ so the Trainer is set up correctly, e.g.

class BartTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def compute_loss(self, model, inputs, return_outputs=False):
        # implement custom logic here, e.g. run the model and
        # compute your loss from outputs.logits and inputs["labels"]
        outputs = model(**inputs)
        custom_loss = ...
        return (custom_loss, outputs) if return_outputs else custom_loss
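
for a concrete sketch of the custom-loss part, here's one possibility - a token-level cross-entropy that ignores padding. this is just my example choice, not something from the thread; the pad_token_id=1 default matches BART's tokenizer, but check your own config:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def weighted_token_loss(logits, labels, pad_token_id=1, class_weight=None):
    # flatten (batch, seq_len, vocab) logits and (batch, seq_len) labels,
    # skip positions that hold the pad token
    vocab_size = logits.size(-1)
    return F.cross_entropy(
        logits.view(-1, vocab_size),
        labels.view(-1),
        ignore_index=pad_token_id,
        weight=class_weight,
    )

# quick check with dummy tensors shaped like BART decoder output
logits = torch.randn(2, 5, 10)           # (batch, seq_len, vocab)
labels = torch.randint(0, 10, (2, 5))    # token ids
loss = weighted_token_loss(logits, labels)
print(loss.item())
```

inside compute_loss you'd call outputs = model(**inputs) and then pass outputs.logits together with inputs["labels"] to a function like this.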

there’s a nice template of a custom trainer for question answering that you could work from: trainer_qa.py in the huggingface/transformers repo on GitHub
