Fine-Tune for MultiClass or MultiLabel-MultiClass

Hi again @dikster99, thanks to a tip from Sylvain Gugger, I realised that there’s a much simpler way to implement multi-label classification: just override the compute_loss method of the Trainer!

Here’s an example in PyTorch:

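import torch
from transformers import Trainer

class MultilabelTrainer(Trainer):
    # **kwargs absorbs extra arguments (e.g. num_items_in_batch) that newer Trainer versions pass
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # One independent sigmoid + binary cross-entropy term per label
        loss_fct = torch.nn.BCEWithLogitsLoss()
        num_labels = self.model.config.num_labels
        # BCEWithLogitsLoss needs float targets with the same shape as the logits
        loss = loss_fct(logits.view(-1, num_labels), labels.float().view(-1, num_labels))
        return (loss, outputs) if return_outputs else loss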

I’ve also updated my Colab notebook to reflect the change. Hope that helps!
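
In case it helps to see what that loss is doing, here is a rough pure-Python sketch of the quantity BCEWithLogitsLoss computes with its default mean reduction, plus one possible way to turn logits into predictions afterwards. The helper names (bce_with_logits, predict), the 0.5 threshold, and the toy numbers are my own assumptions for illustration; none of them come from the Trainer API or the notebook.

```python
import math

def bce_with_logits(logits, labels):
    """Mean binary cross-entropy over every (example, label) pair."""
    total, count = 0.0, 0
    for row_logits, row_labels in zip(logits, labels):
        for x, y in zip(row_logits, row_labels):
            # Numerically stable form of -[y*log(s) + (1-y)*log(1-s)], s = sigmoid(x)
            total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
            count += 1
    return total / count

def predict(logits, threshold=0.5):
    """Apply an independent sigmoid per label and threshold it (hypothetical helper)."""
    return [[1 if 1.0 / (1.0 + math.exp(-x)) >= threshold else 0 for x in row]
            for row in logits]

# Toy batch: 2 examples, 3 labels (made-up numbers)
logits = [[2.0, -1.0, 0.3], [-3.0, 0.5, 1.5]]
labels = [[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]  # multi-hot targets, stored as floats

loss = bce_with_logits(logits, labels)
preds = predict(logits)
```

The point is that each label gets its own yes/no decision, so an example can have zero, one, or several labels switched on. That is also why the labels are multi-hot vectors cast to float rather than a single class index.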

PS: you will need to install transformers from the master branch for this to work, i.e. pip install git+https://github.com/huggingface/transformers.git
