Hello,
My goal is to output certain model performance metrics for my multilabel classification problem (I am using a DistilBERT architecture, by the way). Looking at the labels individually, most of them are heavily imbalanced, so I also want to correct for this label (class) imbalance.
I am fairly new to this, so based on some examples and my own experimenting I have done the following:
import torch

def accuracy_thresh(y_pred, y_true, thresh=0.5, sigmoid=True):
    y_pred = torch.from_numpy(y_pred)
    y_true = torch.from_numpy(y_true)
    if sigmoid:
        # y_pred holds raw logits, so map them to probabilities first
        y_pred = y_pred.sigmoid()
    # element-wise accuracy over all (sample, label) pairs
    return ((y_pred > thresh) == y_true.bool()).float().mean().item()
The above code calculates accuracy at a threshold of 0.5; note that it is element-wise, i.e. the fraction of correct predictions over all (sample, label) pairs, not per sample.
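To sanity-check it, here is a toy call on made-up values (the data below is purely illustrative):

```python
import numpy as np

# 2 samples x 3 labels: raw logits and binary targets (made-up values)
logits  = np.array([[ 2.0, -1.0,  0.5],
                    [-0.5,  3.0, -2.0]])
targets = np.array([[1., 0., 1.],
                    [0., 1., 0.]])
print(accuracy_thresh(logits, targets))  # all 6 (sample, label) pairs correct -> 1.0
```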
Next, I wrote a compute_metrics function, which also reports the function above as one of its outputs:
from scipy.special import expit as sigmoid  # numpy-friendly sigmoid
from sklearn.metrics import classification_report

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    y_true = labels
    # turn the logits into probabilities, then binarise at 0.5
    y_pred = sigmoid(predictions)
    y_pred = (y_pred > 0.5).astype(float)
    # all_labels is my list of label names, defined elsewhere
    clf_dict = classification_report(y_true, y_pred, target_names=all_labels,
                                     zero_division=0, output_dict=True)
    return {"accuracy_thresh": accuracy_thresh(predictions, labels),
            "micro f1": clf_dict['micro avg']['f1-score'],
            "macro f1": clf_dict['macro avg']['f1-score'],
            "weighted f1": clf_dict['weighted avg']['f1-score']}
It looks a bit hacky, but it works (it runs). I have the Trainer as follows:
from transformers import Trainer

class MultilabelTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        # remove the labels from the inputs so we can compute the loss ourselves
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # class_weights is the per-label pos_weight tensor described below
        # (it must live on the same device as the logits when training on GPU)
        loss_fct = torch.nn.BCEWithLogitsLoss(pos_weight=class_weights)
        loss = loss_fct(logits.view(-1, self.model.config.num_labels),
                        labels.float().view(-1, self.model.config.num_labels))
        return (loss, outputs) if return_outputs else loss
Note that I am passing pos_weight. Why? Since I am dealing with imbalanced labels, as said above, I have a tensor that contains a weight for each label, calculated as the number of negative cases divided by the number of positive cases for that label.
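For reference, this is roughly how I build that tensor (a sketch; `train_labels` here is a hypothetical (n_samples, n_labels) 0/1 array of my training labels, and the np.maximum guard against labels with zero positives is my own addition):

```python
import numpy as np
import torch

labels_arr = np.array(train_labels)   # (n_samples, n_labels), values in {0, 1}
pos = labels_arr.sum(axis=0)          # positives per label
neg = labels_arr.shape[0] - pos       # negatives per label
class_weights = torch.tensor(neg / np.maximum(pos, 1), dtype=torch.float)
```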
The trainer is then instantiated as:
multi_trainer = MultilabelTrainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer)
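I then train and evaluate in the usual way (this assumes `args` sets an evaluation strategy, so compute_metrics actually gets called; the Trainer prefixes the returned keys with `eval_`):

```python
multi_trainer.train()
metrics = multi_trainer.evaluate()  # e.g. {'eval_accuracy_thresh': ..., 'eval_micro f1': ..., ...}
print(metrics)
```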
My main question is: does what I have done actually make sense? That is, am I actually getting the right performance metrics, given that I want to correct for the imbalance? Or is there a better (e.g. less verbose) alternative to achieve this?