Custom Training Loss Function for Seq2Seq BART

How can we add a custom nn.CrossEntropyLoss() to Seq2SeqTrainer? I saw in the documentation that we can add a custom loss function by subclassing Trainer (Trainer). Can we do the same for Seq2SeqTrainer?

import torch
from torch import nn
from transformers import Trainer

class CustomTrainer(Trainer):
    # **kwargs absorbs extra arguments that newer transformers versions pass (e.g. num_items_in_batch)
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.get("labels")
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits")
        # compute custom loss (suppose one has 3 labels with different weights);
        # the weight tensor must be on the same device as the logits
        loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=logits.device))
        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
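Seq2SeqTrainer is a subclass of Trainer, so the same compute_loss override should carry over. The main difference for BART is the shape of the logits: (batch, target_length, vocab_size) instead of (batch, num_labels), so they are flattened over the vocabulary. A minimal sketch, assuming labels are padded with -100 (the DataCollatorForSeq2Seq default); the class name CustomSeq2SeqTrainer and the token_weights argument (an optional per-token weight tensor of shape (vocab_size,)) are hypothetical:

```python
import torch
from torch import nn
from transformers import Seq2SeqTrainer


class CustomSeq2SeqTrainer(Seq2SeqTrainer):
    """Seq2SeqTrainer whose loss is a (possibly weighted) token-level cross-entropy.

    `token_weights` is a hypothetical extra constructor argument for this sketch.
    """

    def __init__(self, *args, token_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.token_weights = token_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs keeps the override compatible with versions that pass extra arguments
        labels = inputs.get("labels")
        outputs = model(**inputs)
        logits = outputs.get("logits")  # (batch, target_length, vocab_size)
        weight = None
        if self.token_weights is not None:
            # move the weights to wherever the model runs
            weight = self.token_weights.to(logits.device)
        # -100 marks padded label positions, which are excluded from the loss
        loss_fct = nn.CrossEntropyLoss(weight=weight, ignore_index=-100)
        loss = loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
        return (loss, outputs) if return_outputs else loss
```

It is constructed like a regular Seq2SeqTrainer, with token_weights passed as an extra keyword argument. Note that this override replaces the model's default loss rather than adding to it.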

Does this custom loss function then use only nn.CrossEntropyLoss()? My goal is to add the custom loss on top of the default loss.
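Overriding compute_loss replaces the default loss. To keep it, note that when labels are present in inputs, BART's forward pass already returns its standard token-level cross-entropy as outputs.loss, so a custom term can be summed with it. A sketch under that assumption; the weighted cross-entropy as the extra term, the mixing factor alpha, and the names extra_token_loss, token_weights and DefaultPlusCustomTrainer are all hypothetical choices for illustration:

```python
import torch
from torch import nn
from transformers import Seq2SeqTrainer


def extra_token_loss(logits, labels, token_weights=None):
    """Hypothetical extra term: weighted token-level cross-entropy that ignores -100 labels."""
    weight = None if token_weights is None else token_weights.to(logits.device)
    loss_fct = nn.CrossEntropyLoss(weight=weight, ignore_index=-100)
    return loss_fct(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))


class DefaultPlusCustomTrainer(Seq2SeqTrainer):
    """Total loss = model's default loss + alpha * extra_token_loss (hypothetical arguments)."""

    def __init__(self, *args, alpha=0.5, token_weights=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.alpha = alpha
        self.token_weights = token_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.get("labels")
        # with labels in `inputs`, the model computes its default loss itself
        outputs = model(**inputs)
        default_loss = outputs.get("loss")
        custom_loss = extra_token_loss(outputs.get("logits"), labels, self.token_weights)
        loss = default_loss + self.alpha * custom_loss
        return (loss, outputs) if return_outputs else loss
```

Setting alpha to 0 recovers the default training objective, which makes it easy to check that the extra term is what changes the results.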


Hey @Hiteshwar! Do you know the answer to this? I have the same question now!