Class weights for BertForSequenceClassification

I have unbalanced data, with a couple of classes that have relatively small sample sizes. Is there a way to assign class weights to the BertForSequenceClassification class, maybe in BertConfig, as we can do with nn.CrossEntropyLoss?

Thank you, in advance!


No, you need to compute the loss outside of the model for this. If you’re using Trainer, see the Trainer documentation on how to change the loss from the default one computed by the model (by subclassing Trainer and overriding compute_loss).

Hi Sylvain,

Glad to hear from you outside of FastAI. Well, I am here as a “Beginner” and will have to study Trainer more. In the meantime, I tried the following:

  1. Run BertForSequenceClassification as usual.
  2. Take the logits from the output (discarding the loss computed by BERT).
  3. Compute a new, weighted loss with nn.CrossEntropyLoss.
  4. Call loss.backward().

The model runs fine, but I am not sure whether this is a legitimate approach…



That is the correct approach!
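Why discarding the model's loss is safe: for single-label classification, BertForSequenceClassification computes a plain, unweighted nn.CrossEntropyLoss over its logits, so recomputing the loss yourself only changes how examples are weighted. The sketch below uses random tensors as stand-ins for real logits and labels (purely illustrative) and checks two things: all-ones weights reproduce the unweighted loss, and with `reduction='mean'` the weighted loss is the per-example loss scaled by the weight of each example's true class, divided by the sum of those weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-ins for model outputs: 8 examples, 3 classes (illustrative only)
logits = torch.randn(8, 3)
labels = torch.tensor([0, 0, 0, 0, 0, 1, 1, 2])

# Unweighted loss, as the model computes it internally for single-label tasks
default_loss = nn.CrossEntropyLoss()(logits, labels)

# All-ones weights give the same value
uniform_loss = nn.CrossEntropyLoss(weight=torch.ones(3))(logits, labels)

# Up-weight the rarer classes 1 and 2
class_weights = torch.tensor([1.0, 2.0, 4.0])
weighted_loss = nn.CrossEntropyLoss(weight=class_weights)(logits, labels)

# Manual version: per-example negative log-likelihood, weighted by the
# weight of each example's true class, normalized by the sum of those weights
nll = -F.log_softmax(logits, dim=1)[torch.arange(len(labels)), labels]
w = class_weights[labels]
manual_loss = (w * nll).sum() / w.sum()

print(torch.allclose(default_loss, uniform_loss))   # True
print(torch.allclose(weighted_loss, manual_loss))   # True
```

Note the normalization: a batch full of rare-class examples does not get a larger loss just because their weights are large; the weights only shift the balance between examples within a batch.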

Fantastic! It means a lot to get support from somebody like you, Sylvain. Thanks!

How do I take out the logits and calculate the new loss? Do you have a code example/snippet?


Hi there,

I’m assuming you’re using the standard BertForSequenceClassification. So, instead of doing

```python
outputs = model(**inputs)
loss = outputs['loss']
```

you do

```python
import torch

outputs = model(**inputs)
logits = outputs['logits']
# class_weights: float tensor with one weight per class, moved to the logits' device
criterion = torch.nn.CrossEntropyLoss(weight=class_weights.to(logits.device))
loss = criterion(logits, inputs['labels'])
```

assuming `inputs` is the dictionary that feeds the model. The loss function will then apply the class weights. See the PyTorch documentation for CrossEntropyLoss for details.
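The snippets in this thread assume a `class_weights` tensor already exists. One common heuristic, the same formula scikit-learn uses for `class_weight='balanced'`, is n_samples / (n_classes * count_per_class), so rarer classes get larger weights. A minimal sketch, assuming the training labels are available as a 1-D tensor of integer class ids; `balanced_class_weights` and `train_labels` are hypothetical names for illustration:

```python
import torch

def balanced_class_weights(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Inverse-frequency weights: n_samples / (num_classes * count_c)."""
    counts = torch.bincount(labels, minlength=num_classes).float()
    # Classes absent from the data would give a zero count; clamp to avoid division by zero
    counts = counts.clamp(min=1.0)
    return labels.numel() / (num_classes * counts)

# Hypothetical training labels: class 0 dominates, classes 1 and 2 are rare
train_labels = torch.tensor([0] * 8 + [1] * 2 + [2] * 2)
class_weights = balanced_class_weights(train_labels, num_classes=3)
print(class_weights)  # tensor([0.5000, 2.0000, 2.0000])
```

On a GPU, move the tensor to the same device as the logits (for example `class_weights.to(logits.device)`) before building the criterion, otherwise CrossEntropyLoss fails with a device mismatch.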

If, instead, you’re using Trainer, you’ll have to override the compute_loss method, as mentioned above.

For example, adapting the original implementation, it would look something like this:

```python
from transformers import Trainer
import torch

class MyTrainer(Trainer):
    def __init__(self, class_weights, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # You pass the class weights (float tensor, one entry per class) when instantiating the Trainer
        self.class_weights = class_weights

    # **kwargs absorbs extra arguments (e.g. num_items_in_batch) passed by newer transformers versions
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            loss = self.label_smoother(outputs, labels)
        else:
            # We don't use .loss here since the model may return tuples instead of ModelOutput.

            # Changes start here
            # loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
            logits = outputs['logits']
            criterion = torch.nn.CrossEntropyLoss(weight=self.class_weights.to(logits.device))
            loss = criterion(logits, inputs['labels'])
            # Changes end here

        return (loss, outputs) if return_outputs else loss
```

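This assumes you’re not using label smoothing: when it is enabled, the label_smoother branch runs instead and the class weights are ignored.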
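If you do want label smoothing together with class weights, one option (an alternative to the Trainer's `label_smoother`, not what the code above does) is to let nn.CrossEntropyLoss handle both, since PyTorch 1.10 added a `label_smoothing` argument. A minimal sketch with random stand-in logits; the 0.1 smoothing value and the weights are arbitrary choices for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(4, 3)               # stand-in for outputs['logits']
labels = torch.tensor([0, 1, 2, 0])
class_weights = torch.tensor([0.5, 2.0, 2.0])

# Class weights and label smoothing in a single criterion (PyTorch >= 1.10)
criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)
loss = criterion(logits, labels)

# With label_smoothing=0.0 it reduces to the plain weighted loss
plain = nn.CrossEntropyLoss(weight=class_weights)(logits, labels)
no_smoothing = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.0)(logits, labels)

print(loss.item(), plain.item())
```

In the Trainer subclass you would then keep `label_smoothing_factor` at 0 (so `self.label_smoother` is None) and build this criterion in compute_loss instead, so the class weights are not skipped.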


@lucasresck Following your suggestion, I did the following:

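```python
import torch
import torch.nn as nn

# Assumes model, optimizer, train_dataloader, num_epochs, device and a
# float tensor `weights` (one entry per class) are defined earlier
criterion = nn.CrossEntropyLoss(weight=weights.to(device), reduction='mean')
num_mb_train = len(train_dataloader)
train_losses = []

model.train()
for epoch in range(num_epochs):
  train_loss = 0

  for step, batch in enumerate(train_dataloader):
    batch = tuple(t.to(device) for t in batch)
    b_input_ids, b_input_mask, b_token_type, b_labels = batch

    optimizer.zero_grad()
    outputs = model(b_input_ids, token_type_ids=b_token_type, attention_mask=b_input_mask, labels=b_labels)
    # With labels passed, outputs[0] is the model's unweighted loss and outputs[1] the logits
    loss = criterion(outputs[1], b_labels)
    loss.backward()
    # Clip gradients after backward() and before the optimizer step
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    train_loss += loss.item()

    if step % 50 == 0:
      print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
            .format(epoch + 1, num_epochs, step + 1, num_mb_train, loss.item()))

  train_losses.append(train_loss / num_mb_train)
```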

Am I doing this the right way?


Heads up: the parameter of torch.nn.CrossEntropyLoss is weight (not weights).


Supposing outputs[1] refers to outputs['logits'], I believe it’s correct.

Thanks for the correction. I would edit the answer if possible.
