Class weights for BertForSequenceClassification

Hi there,

I’m assuming you’re using the standard BertForSequenceClassification. So, instead of doing

outputs = model(**inputs)
loss = outputs['loss']

you do

outputs = model(**inputs)
logits = outputs['logits']
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
loss = criterion(logits, inputs['labels'])

assuming inputs is the dictionary that feeds the model. The loss function will then weight each class's contribution accordingly. Note that the keyword argument is weight (singular), and class_weights must be a torch.Tensor on the same device as the logits. Documentation for CrossEntropyLoss is here: https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html
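If you still need to build class_weights, a common choice (not specific to transformers) is inverse class frequency. A minimal sketch, where train_labels is a placeholder name for your list of integer label ids:

import torch

# Hypothetical example: train_labels holds the integer class id of each
# training example.
train_labels = [0, 0, 0, 0, 1, 1, 2, 0]

counts = torch.bincount(torch.tensor(train_labels))
# Inverse-frequency weights: n_samples / (n_classes * count_per_class),
# so rarer classes get larger weights.
class_weights = counts.sum() / (len(counts) * counts.float())

Remember to move the tensor to the model's device (e.g. class_weights.to(model.device)) before handing it to CrossEntropyLoss.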

If, instead, you’re using Trainer, you’ll have to override the compute_loss method, as mentioned previously:

As an example, modifying the original implementation, it’d be something like

from transformers import Trainer
import torch

class MyTrainer(Trainer):
    def __init__(self, class_weights, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The class weights are passed in when instantiating the Trainer;
        # they should be a torch.Tensor on the same device as the model.
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.
        Subclass and override for custom behavior.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            loss = self.label_smoother(outputs, labels)
        else:
            # We don't use .loss here since the model may return tuples instead of ModelOutput.

            # Changes start here
            # loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
            logits = outputs['logits']
            criterion = torch.nn.CrossEntropyLoss(weight=self.class_weights)
            loss = criterion(logits, inputs['labels'])
            # Changes end here

        return (loss, outputs) if return_outputs else loss

if you’re not using label smoothing. With label smoothing enabled, the label_smoother branch computes the loss instead, and the class weights are ignored.
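For completeness, instantiating the subclass looks the same as for a plain Trainer, plus the extra argument. A minimal sketch, assuming model, training_args, train_dataset, and eval_dataset already exist (all placeholder names) and class_weights is the tensor from above:

trainer = MyTrainer(
    class_weights=class_weights.to(model.device),
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()

Building the criterion inside compute_loss keeps the change local to the snippet above; you could equally create it once in __init__ and reuse it.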
