Defining custom compute_metrics for multiclass classification

Hi, I’m currently trying to train a classifier that infers the technical specifications of sailboats by fine-tuning a ViT. Until now I have trained a separate model for each specification, but I would like to try a multitask model and see whether it benefits from learning several specifications at once. I have done this by adding 5 classification heads on top of the ViT like this:

```python
import torch.nn as nn
from transformers import AutoModel

class MultitaskViT(nn.Module):
    def __init__(self):
        super().__init__()
        # Shared ViT backbone; `checkpoint` and the *_Classes lists are defined elsewhere
        self.base_model = AutoModel.from_pretrained(checkpoint)
        self.linear1 = nn.Linear(self.base_model.config.hidden_size, 1024)
        # One classification head per specification
        self.Hull_Type = nn.Linear(1024, len(Hull_Type_Classes))
        self.Rigging_Type = nn.Linear(1024, len(Rigging_Type_Classes))
        self.Construction = nn.Linear(1024, len(Construction_Classes))
        self.Ballast_Type = nn.Linear(1024, len(Ballast_Type_Classes))
        self.Designer = nn.Linear(1024, len(Designer_Classes))

    def forward(self, **inputs):
        # Pooled [CLS] representation from the backbone
        outputs = self.base_model(pixel_values=inputs["pixel_values"]).pooler_output
        outputs = self.linear1(outputs)
        # Return raw logits: nn.CrossEntropyLoss applies log-softmax itself,
        # so an extra nn.Softmax here would squash the training signal
        return {"Hull Type": self.Hull_Type(outputs),
                "Rigging Type": self.Rigging_Type(outputs),
                "Construction": self.Construction(outputs),
                "Ballast Type": self.Ballast_Type(outputs),
                "Designer": self.Designer(outputs)}
```

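PyTorch's `nn.CrossEntropyLoss` applies log-softmax to its input itself and expects raw, unnormalized logits, so feeding it the output of `nn.Softmax` applies softmax twice. A small sketch of the effect on one very confident, correct prediction, using plain NumPy as a stand-in for the PyTorch loss (the logit values are made up for illustration):

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over the last axis
    z = x - x.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def cross_entropy_from_logits(logits, target):
    # What nn.CrossEntropyLoss does for one sample: softmax, log, pick the target class
    return -np.log(softmax(logits)[target])

logits = np.array([10.0, 0.0, 0.0])  # very confident, correct prediction for class 0

loss_on_logits = cross_entropy_from_logits(logits, 0)
loss_on_probs = cross_entropy_from_logits(softmax(logits), 0)  # softmax applied twice

print(f"loss on raw logits:      {loss_on_logits:.4f}")
print(f"loss on softmax outputs: {loss_on_probs:.4f}")
```

Because softmax outputs lie in [0, 1], the second softmax can never become sharp: with 3 classes the loss bottoms out at log(1 + 2/e) ≈ 0.55 however confident the model is, so training can never drive it toward zero. Returning raw logits from each head and applying softmax only when probabilities are actually needed avoids this; `argmax` picks the same class either way.
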
I am going to train the model with a transformers `Trainer`, where I override `compute_loss` with a custom cross-entropy loss that sums the losses of all the prediction heads.

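```python
import torch.nn as nn
from transformers import Trainer

class MultiTaskTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # Sum one cross-entropy term per prediction head; `label_types` lists the
        # label keys in the batch, which match the keys of the model's output dict
        criterion = nn.CrossEntropyLoss()
        model_output = model(**inputs)
        total_loss = sum(criterion(model_output[label], inputs[label]) for label in label_types)
        # During evaluation the Trainer asks for (loss, outputs) so it can collect predictions
        return (total_loss, model_output) if return_outputs else total_loss
```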
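What the overridden `compute_loss` computes can be sketched in NumPy, assuming each head returns raw logits of shape `(batch, num_classes)` and `nn.CrossEntropyLoss` keeps its default mean reduction over the batch. The two heads, logits and targets below are hypothetical illustration values:

```python
import numpy as np

def log_softmax(logits):
    # Stable log-softmax over the class axis (axis=1)
    z = logits - logits.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

def cross_entropy(logits, targets):
    # Mean negative log-likelihood over the batch, like nn.CrossEntropyLoss()
    log_probs = log_softmax(logits)
    return -log_probs[np.arange(len(targets)), targets].mean()

def multitask_loss(outputs, labels, label_types):
    # One scalar cross-entropy per head, then a plain unweighted sum
    return sum(cross_entropy(outputs[name], labels[name]) for name in label_types)

# Hypothetical batch of 2 samples, two heads with 3 and 2 classes
label_types = ["Hull Type", "Rigging Type"]
outputs = {
    "Hull Type": np.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]]),
    "Rigging Type": np.array([[1.0, -1.0], [0.0, 2.0]]),
}
labels = {"Hull Type": np.array([0, 2]), "Rigging Type": np.array([0, 1])}

total = multitask_loss(outputs, labels, label_types)
print(f"total loss: {total:.4f}")
```

Heads can have different numbers of classes because each term is reduced to a scalar before the sum. An untrained head that predicts roughly uniformly has a loss near log K for K classes, so heads with more classes contribute more to the sum early in training.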

However, I would also like to log some metrics to Weights & Biases (wandb) while training. Ideally, I would like to log the accuracy, F1, recall and precision separately for each prediction head. I have written the following compute_metrics function, but it does not log anything to wandb:

```python
def compute_metrics_multitask(eval_pred):
    metrics = {}
    for i , label in enumerate(label_types):
        predictions, labels = eval_pred[i]
        predictions = np.argmax(predictions, axis=1)
        accuracy_score = accuracy.compute(predictions=predictions, references=labels).values()
        f1_score = f1.compute(predictions=predictions, references=labels , average="macro").values()
        precision_score = precision.compute(predictions=predictions, references=labels , average="macro").values()
        recall_score = recall.compute(predictions=predictions, references=labels , average="macro").values()
        metrics["accuracy_"+label] = accuracy_score
        metrics["f1_"+label] = f1_score
        metrics["precision_"+label] = precision_score
        metrics["recall_"+label] = recall_score
    return metrics
```

I suspect that the way I handle the “eval_pred” input is what's wrong.
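A few things in this setup could each stop metrics from appearing. As far as I know, `Trainer` only calls `compute_metrics` when it has collected labels during evaluation, and it finds them via `TrainingArguments.label_names`; for a custom `nn.Module` whose `forward` takes `**inputs` it cannot infer those names, so passing `label_names=label_types` is likely needed (along with `remove_unused_columns=False`, if not already set, since that signature gives the Trainer no input column names to keep). With labels found, evaluation calls `compute_loss(model, inputs, return_outputs=True)`, so the override has to accept that argument and return `(loss, outputs)`. `eval_pred.predictions` is then a tuple with one array per output (in the model's dict order) and `eval_pred.label_ids` a tuple in `label_names` order; `eval_pred[0]` is the whole predictions tuple and `eval_pred[1]` the whole labels tuple, so `eval_pred[i]` does not select head `i`. Finally, `.compute(...).values()` returns a `dict_values` view rather than a number; indexing the result, e.g. `accuracy.compute(...)["accuracy"]`, gives a float wandb can log. The sketch below assumes all of that and, to stay self-contained, swaps the `evaluate` metrics for plain NumPy macro averages; `label_types` is a hypothetical list that must be in the same order as the model's output dict:

```python
from types import SimpleNamespace

import numpy as np

# Hypothetical head names, in the same order as the model's output dict
label_types = ["Hull Type", "Rigging Type", "Construction", "Ballast Type", "Designer"]

def macro_scores(y_true, y_pred):
    # Per-class precision/recall/F1 averaged over classes seen in y_true or y_pred,
    # scoring 0 when a denominator is 0 (intended to match sklearn's average="macro")
    precisions, recalls, f1s = [], [], []
    for c in np.union1d(y_true, y_pred):
        tp = np.sum((y_pred == c) & (y_true == c))
        fp = np.sum((y_pred == c) & (y_true != c))
        fn = np.sum((y_pred != c) & (y_true == c))
        p = tp / (tp + fp) if tp + fp else 0.0
        r = tp / (tp + fn) if tp + fn else 0.0
        precisions.append(p)
        recalls.append(r)
        f1s.append(2 * p * r / (p + r) if p + r else 0.0)
    return float(np.mean(precisions)), float(np.mean(recalls)), float(np.mean(f1s))

def compute_metrics_multitask(eval_pred):
    metrics = {}
    # With label_names set, both fields are tuples holding one array per head
    all_logits, all_labels = eval_pred.predictions, eval_pred.label_ids
    for i, label in enumerate(label_types):
        preds = np.argmax(all_logits[i], axis=1)
        refs = np.asarray(all_labels[i])
        precision, recall, f1 = macro_scores(refs, preds)
        key = label.replace(" ", "_")  # "Hull Type" -> "Hull_Type" for tidier wandb keys
        metrics[f"accuracy_{key}"] = float(np.mean(preds == refs))
        metrics[f"f1_{key}"] = f1
        metrics[f"precision_{key}"] = precision
        metrics[f"recall_{key}"] = recall
    return metrics

# Fake evaluation output standing in for EvalPrediction: 4 samples, per-head class counts
rng = np.random.default_rng(0)
num_classes = [3, 4, 2, 2, 5]
fake = SimpleNamespace(
    predictions=tuple(rng.normal(size=(4, k)) for k in num_classes),
    label_ids=tuple(rng.integers(0, k, size=4) for k in num_classes),
)
print(compute_metrics_multitask(fake))
```

The `Trainer` prefixes returned keys with `eval_` before logging, so these should show up in wandb as e.g. `eval_f1_Hull_Type`.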

Help would be appreciated :slight_smile: