I have an unbalanced dataset with a couple of classes that have relatively small sample sizes. I am wondering if there is a way to assign class weights to the BertForSequenceClassification class, maybe in BertConfig, as we can do in nn.CrossEntropyLoss.
No, you need to compute the loss outside of the model for this. If you’re using Trainer, see here for how to change the loss from the default computed by the model.
Glad to hear from you outside of fastai! Well, I am here as a “Beginner” and will have to study Trainer more. In the meantime, I tried the following:

- Run BertForSequenceClassification as usual.
- Take the logits from the output (discarding the loss computed by BERT).
- Calculate a new loss with nn.CrossEntropyLoss, using the class weights.
- Call loss.backward() on that new loss.

The model runs okay, but I am not sure if this is a legitimate approach…
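The steps above can be sketched like this. This is a minimal illustration: a random tensor stands in for the logits returned by BertForSequenceClassification, and the class weights, batch size, and number of classes are all made up for the example.

```python
import torch
import torch.nn as nn

# Hypothetical weights for a 3-class problem where class 2 is underrepresented
class_weights = torch.tensor([1.0, 1.0, 5.0])
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Stand-in for the logits from the model output: batch of 4 examples, 3 classes.
# requires_grad=True mimics logits produced by a real forward pass.
logits = torch.randn(4, 3, requires_grad=True)
labels = torch.tensor([0, 2, 1, 2])

loss = criterion(logits, labels)  # weighted loss; the loss from the BERT run is discarded
loss.backward()                   # gradients flow back through the logits
```

With a real model you would call `loss.backward()` on the loss computed from the model’s actual logits, so the gradients propagate into the model parameters.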
Yes, that is a legitimate approach, assuming inputs is the dictionary that feeds the model: take the logits from the model output and pass them, together with the labels, to an nn.CrossEntropyLoss built with your class weights. The loss function will then leverage the class weights. Documentation for CrossEntropyLoss can be found here.
If, instead, you’re using Trainer, you’ll have to override the compute_loss method, as mentioned previously:
from transformers import Trainer
import torch


class MyTrainer(Trainer):
    def __init__(self, class_weights, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # You pass the class weights when instantiating the Trainer
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            loss = self.label_smoother(outputs, labels)
        else:
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            # Changes start here
            # loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
            logits = outputs["logits"]
            # Note: the keyword argument is `weight`, not `weights`; also move the
            # weights to the same device as the logits.
            criterion = torch.nn.CrossEntropyLoss(weight=self.class_weights.to(logits.device))
            loss = criterion(logits, inputs["labels"])
            # Changes end here
        return (loss, outputs) if return_outputs else loss
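As for choosing the weights themselves, a common heuristic is inverse class frequency, normalized so the average weight per training sample is 1 (the same idea as sklearn’s “balanced” mode). The label counts below are made up for illustration, and the resulting tensor is what you would pass as class_weights when instantiating MyTrainer:

```python
import torch
from collections import Counter

# Hypothetical training labels: class 2 is heavily underrepresented
train_labels = [0] * 90 + [1] * 8 + [2] * 2

counts = Counter(train_labels)
num_classes = len(counts)
total = len(train_labels)

# Inverse-frequency weights: total / (num_classes * count_of_class)
weights = [total / (num_classes * counts[c]) for c in range(num_classes)]
class_weights = torch.tensor(weights)

# trainer = MyTrainer(class_weights=class_weights, model=model, args=training_args, ...)
```

Rare classes get proportionally larger weights, so mistakes on them contribute more to the loss.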