Hi there,
I’m assuming you’re using the standard `BertForSequenceClassification`. So, instead of doing
```python
outputs = model(**inputs)
loss = outputs['loss']
```
you do
```python
outputs = model(**inputs)
logits = outputs['logits']
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)  # note: the argument is `weight`, not `weights`
loss = criterion(logits, inputs['labels'])
```
assuming `inputs` is the dictionary that feeds the model and `class_weights` is a tensor with one weight per class. Note that the keyword argument is `weight` (singular), and the weight tensor must be on the same device as the logits. The loss function will then scale each class’s contribution by its weight. Documentation for `CrossEntropyLoss` can be found here.
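
If it helps, here’s a small self-contained sketch of this first approach. The class counts and the inverse-frequency weighting are just one common convention for choosing the weights, not something prescribed by PyTorch:

```python
import torch

# Hypothetical class counts for an imbalanced 3-class problem
class_counts = torch.tensor([500.0, 100.0, 50.0])
# One common convention: inverse-frequency weights, normalized by the number of classes
class_weights = class_counts.sum() / (len(class_counts) * class_counts)

# If training on GPU, move class_weights to the same device as the logits first
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)

# Dummy logits and labels just to exercise the call;
# in practice the logits come from the model and the labels from your batch
logits = torch.randn(4, 3)
labels = torch.tensor([0, 1, 2, 1])
loss = criterion(logits, labels)
print(loss)  # the rarer classes (1 and 2) now contribute more to the loss
```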
If, instead, you’re using `Trainer`, you’ll have to override the `compute_loss` method, as mentioned earlier. As an example, modifying the original implementation, it would be something like this:
```python
from transformers import Trainer
import torch


class MyTrainer(Trainer):
    def __init__(self, class_weights, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # You pass the class weights when instantiating the Trainer
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        outputs = model(**inputs)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            loss = self.label_smoother(outputs, labels)
        else:
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            # Changes start here
            # loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
            logits = outputs["logits"]
            # The argument is `weight`, and it must live on the same device as the logits
            criterion = torch.nn.CrossEntropyLoss(weight=self.class_weights.to(logits.device))
            loss = criterion(logits, inputs["labels"])
            # Changes end here

        return (loss, outputs) if return_outputs else loss
```
This applies the class weights only when you’re not using label smoothing; if you are, the `label_smoother` branch computes the loss instead and the weights are ignored.
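
For completeness, here’s a minimal sketch of how you might instantiate the subclass above. The model name, number of labels, and weight values are placeholders, not part of the original answer:

```python
import torch
from transformers import AutoModelForSequenceClassification, TrainingArguments

# Hypothetical 3-class setup; replace the model, weights, and data with your own
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=3
)
class_weights = torch.tensor([1.0, 2.5, 0.8])  # e.g. derived from class frequencies

trainer = MyTrainer(
    class_weights=class_weights,
    model=model,
    args=TrainingArguments(output_dir="out"),
    # train_dataset=your_tokenized_dataset,  # supply your own tokenized dataset
)
# trainer.train()
```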