Multi-label token classification

Hi @murdockthedude. I’m using some sensitive (biomedical) data and my use case is actually a little more complicated than ‘just’ multi-label NER, so I’d have to make up some dummy data and simplify my notebook a bit. Let me think more about that and how to make a shareable notebook.

In the meantime, here are answers to your first two questions.

  1. The custom trainer is all that’s needed, although you probably also want to implement a special compute_metrics function and pass it in when you instantiate the trainer, so that you can do early stopping.
  2. In my case I used BertForTokenClassification, but AutoModelForTokenClassification should work too, I think.
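
For anyone looking for a starting point, here is a minimal sketch of what such a compute_metrics function could look like. It is not the code used in this thread. It assumes the logits and labels both have shape (batch, seq_len, num_labels), that labels are multi-hot vectors, and that padded or special tokens have every label set to -100. The micro-averaged scores are just one reasonable choice.

```python
import numpy as np

def compute_metrics(eval_pred):
    """Micro-averaged precision/recall/F1 for multi-label token classification."""
    logits, labels = eval_pred  # assumed: a (logits, labels) pair
    logits = np.asarray(logits)
    labels = np.asarray(labels)

    # Assumption: ignored tokens carry -100 in every label slot.
    mask = (labels != -100).all(axis=-1)

    # Thresholding logits at 0 matches thresholding sigmoid outputs at 0.5.
    preds = (logits[mask] > 0).astype(int)
    gold = labels[mask].astype(int)

    tp = int(((preds == 1) & (gold == 1)).sum())
    fp = int(((preds == 1) & (gold == 0)).sum())
    fn = int(((preds == 0) & (gold == 1)).sum())

    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}
```

The returned "f1" value could then serve as the metric that early stopping monitors.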

Hey @drussellmrichie, totally understand, thank you. I’ll try to get a small notebook working too to see if I can tape this all together.

One question I have: assuming I implement the custom trainer approach above, at inference time for multi-label token classification, do you just take the individual output logits and run them through a sigmoid activation to get the final per-token, per-label probabilities (as opposed to the single-label case, which runs them through a softmax)? Or is it something more complex than that?

Thanks again!

@murdockthedude. I don’t think you even need to bother with sigmoid if you only want the 0/1 predictions. You can just threshold the logits at zero with a step function, as in lambda x: 1 if x > 0 else 0, because sigmoid(x) > 0.5 exactly when x > 0. There are efficient TF and PyTorch functions for that.
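
A small PyTorch sketch of that equivalence, using made-up logits for two tokens and three labels. If you need the probabilities themselves, the sigmoid is still the way to get them.

```python
import torch

# Made-up logits with shape (tokens, labels).
logits = torch.tensor([[2.3, -0.7, 0.1],
                       [-1.5, 0.0, 4.2]])

# Option A: turn logits into probabilities, then threshold at 0.5.
probs = torch.sigmoid(logits)
preds_from_probs = (probs > 0.5).long()

# Option B: threshold the raw logits at 0, skipping the sigmoid.
preds_from_logits = (logits > 0).long()

# Both routes give the same 0/1 decisions.
assert torch.equal(preds_from_probs, preds_from_logits)
```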

Sorry for the slow reply, this is great. I’m still working on getting things functioning, but this approach seems promising so far. Will report back…

Hey, sorry for opening this old thread again, this looks pretty much exactly like what I am looking for at the moment. But how did you actually succeed in passing the one-hot encoded labels through to the Trainer?

Whenever I try this I am getting errors thrown by DataCollatorForTokenClassification that it expects the labels to be integers.
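
One possible workaround is to pass your own collate function as the trainer's data collator instead. The sketch below is an assumption, not the solution used earlier in this thread: collate_multilabel is a hypothetical helper, and it assumes each feature is a dict with "input_ids", "attention_mask", and "labels", where "labels" holds one multi-hot list per token. Padded positions get a label row filled with -100 so the loss can mask them out.

```python
import torch

def collate_multilabel(features, pad_token_id=0, label_pad_value=-100.0):
    """Pad inputs and per-token multi-hot label vectors to a common length."""
    max_len = max(len(f["input_ids"]) for f in features)
    num_labels = len(features[0]["labels"][0])

    input_ids, attention_mask, labels = [], [], []
    for f in features:
        pad = max_len - len(f["input_ids"])
        input_ids.append(f["input_ids"] + [pad_token_id] * pad)
        attention_mask.append(f["attention_mask"] + [0] * pad)
        # Each padded token gets a full row of the ignore value.
        labels.append(f["labels"] + [[label_pad_value] * num_labels] * pad)

    return {
        "input_ids": torch.tensor(input_ids, dtype=torch.long),
        "attention_mask": torch.tensor(attention_mask, dtype=torch.long),
        # Float labels, since BCEWithLogitsLoss expects float targets.
        "labels": torch.tensor(labels, dtype=torch.float),
    }
```

The pad_token_id default of 0 is a placeholder; in practice it should come from your tokenizer.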


Can anyone please post their compute_metrics function?

Help much appreciated.

@BunnyNoBugs @murdockthedude
I am working on a morphological analysis problem where each token has multiple labels. Could you share a sample notebook or working example so that I can understand the approach and experiment with it on my problem?

Can you share your code / sample working example?

Unfortunately I can’t. The code is on my old workplace’s private laptop and in its git repo.

However, we published a paper based on this work here:

You might try emailing some of my coauthors, like Victor Ruiz. He might be willing to help you.

You can take a look here. The code is a bit messy since it was finished under time pressure, but it gets multi-label token classification done, and we achieved good results with it.


Is the data format below correct for the code above?

features: [‘word’, ‘pos’, ‘noun_case’, ‘noun_gender’, ‘noun_number’],
num_rows: 22

Here, ‘pos’, ‘noun_case’, ‘noun_gender’, ‘noun_number’ are binary labels.

Hey everyone!

I am also trying to implement a BERT-based, two-headed model with one multi-label classification head and one multi-class classification head. The two heads are not directly related to each other; they predict two different aspects of a token.

The challenge is now to combine these two heads into one model. So far, I made the following observations:

  • Use BCEWithLogitsLoss for multi-label classification (each token labelled with a multi-hot vector of 1s and 0s)
  • Use CrossEntropyLoss for multi-class classification (each token labelled with the ID of its class)
  • Compute and return the mean of the two loss values
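
The three points above could be sketched roughly as follows. This is only an illustration under assumptions: TwoHeadTokenClassifier is a hypothetical plain PyTorch module, the encoder is assumed to map input IDs directly to a (batch, seq_len, hidden_size) tensor (with a Hugging Face BERT model you would take the last hidden state instead), and tokens to ignore are assumed to be marked with -100 in the multi-class targets.

```python
import torch
from torch import nn

class TwoHeadTokenClassifier(nn.Module):
    """Shared encoder with a multi-label head and a multi-class head."""

    def __init__(self, encoder, hidden_size, num_multilabel, num_classes):
        super().__init__()
        self.encoder = encoder  # assumed to return (batch, seq_len, hidden_size)
        self.multilabel_head = nn.Linear(hidden_size, num_multilabel)
        self.multiclass_head = nn.Linear(hidden_size, num_classes)
        self.bce = nn.BCEWithLogitsLoss()
        self.ce = nn.CrossEntropyLoss(ignore_index=-100)

    def forward(self, input_ids, multilabel_targets=None, class_targets=None):
        hidden = self.encoder(input_ids)
        ml_logits = self.multilabel_head(hidden)
        mc_logits = self.multiclass_head(hidden)
        out = {"multilabel_logits": ml_logits, "multiclass_logits": mc_logits}

        if multilabel_targets is not None and class_targets is not None:
            # Only score the multi-label head on tokens that are not ignored.
            active = class_targets != -100
            bce = self.bce(ml_logits[active], multilabel_targets[active].float())
            ce = self.ce(
                mc_logits.reshape(-1, mc_logits.size(-1)),
                class_targets.reshape(-1),
            )
            # Unweighted mean of the two losses.
            out["loss"] = (bce + ce) / 2
        return out
```

An unweighted mean treats both tasks as equally important; a weighted sum is a common alternative when one loss dominates the other.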

My questions now:
Is it possible to implement two model heads based on the Trainer API? Or do I need to implement a Model class manually?
What are these “class_weights” mentioned in the solution of this post?
Is my assumption correct that I can just compute the mean of the two losses?
Is there any reason for NOT implementing this approach, but rather implementing two separate models?

Best regards,

How is your data encoded? I assume it’s not one-hot encoded, since you explicitly ignore the -100 label ID.

I think our paper has these details. I’d recommend checking that out. Should be open access:

Thanks for replying, I’ll check the paper.