Hi again @dikster99, thanks to a tip from Sylvain Gugger, I realised that there’s a much simpler way to implement multi-label classification: just override the compute_loss function of the Trainer!
Here’s an example in PyTorch:
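import torch
from transformers import Trainer

class MultilabelTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        # Remove the labels so the model does not compute its own loss
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # Sigmoid + binary cross-entropy, applied to each label independently
        loss_fct = torch.nn.BCEWithLogitsLoss()
        # BCEWithLogitsLoss expects float targets with the same shape as the logits
        loss = loss_fct(logits.view(-1, self.model.config.num_labels),
                        labels.float().view(-1, self.model.config.num_labels))
        return (loss, outputs) if return_outputs else loss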
and I’ve updated my Colab notebook to reflect the change. Hope that helps!
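In case it is useful to see what the loss above is doing, here is a minimal plain-Python sketch of the quantity that BCEWithLogitsLoss computes with its default mean reduction: a sigmoid followed by binary cross-entropy on every (example, label) pair, averaged over all of them. The function name bce_with_logits and the toy batch are made up for illustration, and this is not how PyTorch implements it internally.

```python
import math

def bce_with_logits(logits, labels):
    """Mean binary cross-entropy over every (example, label) pair."""
    total, count = 0.0, 0
    for row_logits, row_labels in zip(logits, labels):
        for x, y in zip(row_logits, row_labels):
            # Numerically stable form of
            # -[y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]
            total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
            count += 1
    return total / count

# Hypothetical batch: 2 examples, 3 labels each, with multi-hot targets
logits = [[2.0, -1.0, 0.5], [-3.0, 0.0, 1.5]]
labels = [[1.0, 0.0, 1.0], [0.0, 0.0, 1.0]]
loss = bce_with_logits(logits, labels)
print(loss)
```

Each label is scored independently, so an example can have several labels switched on at once, which is why the labels are passed as multi-hot float vectors rather than a single class index.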
PS you will need to install transformers from the master branch for this to work, i.e. pip install git+https://github.com/huggingface/transformers.git