Multilabel classification performance metrics using Trainer API


My goal is to output certain model performance metrics for my multilabel classification problem (I am using a DistilBERT architecture). Looked at individually, most of the labels are heavily imbalanced, so I also want to correct for this label (or class) imbalance.

I am fairly new to this. Based on some examples and my own experimenting, I have done the following:

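```python
import torch

def accuracy_thresh(y_pred, y_true, thresh=0.5, sigmoid=True):
    # y_pred: raw logits, y_true: multi-hot labels (NumPy arrays from the Trainer)
    y_pred = torch.from_numpy(y_pred)
    y_true = torch.from_numpy(y_true)
    if sigmoid:
        y_pred = y_pred.sigmoid()  # logits -> probabilities
    # Fraction of (sample, label) cells where the thresholded prediction matches the label
    return ((y_pred > thresh) == y_true.bool()).float().mean().item()
```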

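The above code computes element-wise accuracy at a threshold of 0.5: the fraction of all (sample, label) pairs for which the thresholded prediction matches the true label.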
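To see why element-wise accuracy can be misleading when labels are imbalanced, here is a small sketch on made-up data (the arrays are hypothetical, chosen only for illustration). A "model" that never predicts any label still scores high on this metric, while macro F1 is zero:

```python
import numpy as np
from sklearn.metrics import f1_score

# Hypothetical toy data: 8 samples, 3 labels, each label positive in only 1 of 8 samples
y_true = np.zeros((8, 3), dtype=int)
y_true[0, 0] = 1
y_true[1, 1] = 1
y_true[2, 2] = 1

# A degenerate "model" that never predicts any label
y_pred = np.zeros_like(y_true)

# Element-wise accuracy, as in accuracy_thresh: share of matching (sample, label) cells
elementwise_acc = (y_pred == y_true).mean()
# Macro F1 averages the per-label F1 scores, so every rare label counts equally
macro_f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)

print(f"element-wise accuracy: {elementwise_acc:.3f}")  # 0.875 (21 of 24 cells match)
print(f"macro F1: {macro_f1:.3f}")  # 0.000
```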

Next, I wrote a compute_metrics function that also returns the output of the function above:

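```python
from scipy.special import expit as sigmoid  # NumPy-compatible logistic function
from sklearn.metrics import classification_report

def compute_metrics(eval_pred):
    predictions, labels = eval_pred  # raw logits and multi-hot labels
    y_true = labels.astype(int)
    # Logits -> probabilities -> 0/1 predictions at threshold 0.5
    y_pred = (sigmoid(predictions) > 0.5).astype(int)
    # all_labels: list of label names, defined elsewhere
    clf_dict = classification_report(y_true, y_pred, target_names=all_labels,
                                     zero_division=0, output_dict=True)
    return {"accuracy_thresh": accuracy_thresh(predictions, labels),
            "micro f1": clf_dict["micro avg"]["f1-score"],
            "macro f1": clf_dict["macro avg"]["f1-score"],
            "weighted f1": clf_dict["weighted avg"]["f1-score"]}
```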

It looks a bit hacky, but it works (it runs). My Trainer is defined as follows:

```python
import torch
from transformers import Trainer

class MultilabelTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")  # remove labels so the model does not compute its own loss
        outputs = model(**inputs)
        logits = outputs.logits
        num_labels = self.model.config.num_labels
        # class_weights: 1-D tensor with one pos_weight per label, defined elsewhere
        loss_fct = torch.nn.BCEWithLogitsLoss(pos_weight=class_weights.to(logits.device))
        loss = loss_fct(logits.view(-1, num_labels),
                        labels.float().view(-1, num_labels))
        return (loss, outputs) if return_outputs else loss
```

Note: I am actually using pos_weight. Why? Because my labels are imbalanced, as mentioned above, I have a tensor that contains one weight per label, computed as the number of negative cases divided by the number of positive cases.
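A minimal sketch of computing such a tensor from a multi-hot training label matrix, assuming the labels fit in memory as one tensor (train_labels below is hypothetical toy data). It also checks that BCEWithLogitsLoss with pos_weight only scales the positive term of the loss. The clamp guarding against labels with no positives is an added assumption, not part of the original setup:

```python
import torch

torch.manual_seed(0)

# Hypothetical multi-hot label matrix: 6 training samples x 3 labels
train_labels = torch.tensor([
    [1, 0, 0],
    [0, 0, 1],
    [0, 1, 1],
    [0, 0, 0],
    [1, 0, 1],
    [0, 0, 0],
], dtype=torch.float)

# Per-label counts of positive and negative examples
num_pos = train_labels.sum(dim=0)
num_neg = train_labels.shape[0] - num_pos
# pos_weight = negatives / positives per label (clamp avoids division by zero)
class_weights = num_neg / num_pos.clamp(min=1)
print(class_weights)  # tensor([2., 5., 1.])

# Weighted loss on random logits
logits = torch.randn(6, 3)
loss = torch.nn.BCEWithLogitsLoss(pos_weight=class_weights)(logits, train_labels)

# Same loss written out: positive terms are scaled by pos_weight, negative terms are not
log_p = torch.nn.functional.logsigmoid(logits)       # log(sigmoid(x))
log_not_p = torch.nn.functional.logsigmoid(-logits)  # log(1 - sigmoid(x))
manual = -(class_weights * train_labels * log_p + (1 - train_labels) * log_not_p).mean()
print(loss.item(), manual.item())  # the two values agree
```

Note that pos_weight only affects the training loss; the evaluation metrics are still computed on unweighted predictions.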

The trainer is then created as:

```python
multi_trainer = MultilabelTrainer(
```

My main question is: does what I have done actually make sense? That is, am I getting the right performance metrics, given that I want to correct for imbalance? Or is there a better (e.g. less verbose) alternative to achieve this?
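As one possibly less verbose variant, here is a sketch of a compute_metrics that returns the same quantities (element-wise accuracy at 0.5 plus micro, macro and weighted F1) by calling sklearn's f1_score directly instead of building a full classification_report. It assumes eval_pred unpacks into raw logits and multi-hot labels as NumPy arrays, as in the question:

```python
import numpy as np
from scipy.special import expit
from sklearn.metrics import f1_score

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    # Sigmoid, then threshold at 0.5 to get multi-hot predictions
    y_pred = (expit(logits) > 0.5).astype(int)
    y_true = labels.astype(int)
    # Element-wise accuracy, same definition as accuracy_thresh
    metrics = {"accuracy_thresh": float((y_pred == y_true).mean())}
    for avg in ("micro", "macro", "weighted"):
        metrics[f"{avg} f1"] = float(f1_score(y_true, y_pred, average=avg, zero_division=0))
    return metrics

# Usage on hypothetical logits/labels for 2 samples and 2 labels
result = compute_metrics((np.array([[2.0, -1.0], [-3.0, 0.5]]),
                          np.array([[1.0, 0.0], [0.0, 0.0]])))
print(result)
```

Of the three averages, macro F1 weights every label equally, so it is the one most sensitive to how well the rare labels are predicted; micro F1 pools all (sample, label) decisions and is dominated by frequent labels.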


I’ve created a notebook for you to illustrate this: Fine_tuning_BERT_(and_friends)_for_multi_label_text_classification.ipynb in the NielsRogge/Transformers-Tutorials repository on GitHub.

Actually, there’s no need for a MultilabelTrainer anymore, as you can just set the problem_type of the model’s configuration to “multi_label_classification”.
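A minimal sketch of what this looks like, using a tiny randomly initialised DistilBERT so it runs without downloading weights (the sizes and toy inputs are assumptions for illustration). With pretrained weights you would typically pass the same keyword to from_pretrained, e.g. AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=..., problem_type="multi_label_classification"). One caveat worth checking in your transformers version: the built-in loss is a plain BCEWithLogitsLoss without pos_weight, so keeping per-label weights still needs a custom compute_loss like the one in the question.

```python
import torch
from transformers import DistilBertConfig, DistilBertForSequenceClassification

torch.manual_seed(0)

# Tiny, randomly initialised DistilBERT configured for multi-label classification
config = DistilBertConfig(
    vocab_size=100, dim=32, hidden_dim=64, n_layers=1, n_heads=2,
    num_labels=3, problem_type="multi_label_classification",
)
model = DistilBertForSequenceClassification(config).eval()

input_ids = torch.randint(0, 100, (4, 8))
# Multi-hot targets must be floats for BCEWithLogitsLoss
labels = torch.tensor([[1, 0, 0], [0, 1, 1], [0, 0, 0], [1, 0, 1]], dtype=torch.float)

with torch.no_grad():
    outputs = model(input_ids=input_ids, labels=labels)

# The model's own loss matches an unweighted BCEWithLogitsLoss on its logits
expected = torch.nn.BCEWithLogitsLoss()(outputs.logits, labels)
print(outputs.loss.item(), expected.item())
```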


Thank you very much!