Multilabel text classification Trainer API

Hi all,
Can someone help me do multilabel classification with the Trainer API?

Sure, all you need to do is make sure the problem_type of the model’s configuration is set to multi_label_classification, e.g.:

from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=10, problem_type="multi_label_classification")

This will make sure the appropriate loss function is used (namely, binary cross entropy). Note that the current version of Transformers does not yet support this problem_type for every model, but the next version of Transformers will (as per PR #14180).
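To make the loss concrete, here is a minimal pure-Python sketch of binary cross entropy computed on raw logits, averaged over every (example, label) pair. It is only an illustration and not the library's code; the function name `bce_with_logits` and the toy numbers are made up for this example. To my understanding it matches what `torch.nn.BCEWithLogitsLoss` computes with its default mean reduction. The targets are multi-hot floats, with one 0.0 or 1.0 per label.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross entropy over all (example, label) pairs.

    Uses the numerically stable form
        max(z, 0) - z * y + log(1 + exp(-|z|))
    which equals -[y * log(sigmoid(z)) + (1 - y) * log(1 - sigmoid(z))].
    """
    total = 0.0
    count = 0
    for row_logits, row_targets in zip(logits, targets):
        for z, y in zip(row_logits, row_targets):
            total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
            count += 1
    return total / count


# Two examples, three labels each (toy values, not real model output).
logits = [[2.0, -1.0, 0.5], [-0.5, 3.0, -2.0]]
targets = [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]  # multi-hot float labels
loss = bce_with_logits(logits, targets)
print(f"loss = {loss:.4f}")
```

Each label is scored independently, so an example can have zero, one, or several active labels at once. That is the difference from the softmax cross entropy used for single-label classification.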

I suggest taking a look at the example notebook to do multi-label classification using the Trainer.

Update: I made a notebook myself to illustrate how to fine-tune any encoder-only Transformer model for multi-label text classification: Transformers-Tutorials/Fine_tuning_BERT_(and_friends)_for_multi_label_text_classification.ipynb at master · NielsRogge/Transformers-Tutorials · GitHub

Hi, may I ask where we can find which models support multilabel classification?
Thank you in advance.

Hello @nielsr!
Thanks a lot for your example! I've tried it on my data, and accuracy stays at 0 and ROC AUC at 0.5. I clearly have an issue, but I can't find the cause.
I have 420 labels. Could that be the reason?
I'm a beginner and don't know where to start fixing this, so any help would be greatly appreciated :)

Thanks!

I'm a beginner too, but 420 labels is a lot. If you have plenty of training data, that's fine, but if everything is right with the code, then too little data may be the problem.

Hi Niels!

I'm not sure whether too much time has passed to ask this, but when running your notebook on my data, I keep hitting this error, which I can't troubleshoot:

“Classification metrics can’t handle a mix of multiclass-multioutput and multilabel-indicator targets”

Any ideas? The only difference I can see is that my data was originally coded as binary flags (1, 0) rather than booleans (T, F).

Hi,

If you have a code reproducer, that would make it easier for me to help :)

As much as I appreciate your response, I think I magically fixed it! Thanks (Bedankt!) again for sharing your code in the first place.

Actually, to formally close this out: I was running into the classification error because I had a rogue non-binary value in my validation data (a 2 instead of a 1). The error only occurred when that row ended up in the validation split.
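For anyone hitting the same error, a quick sanity check along these lines could catch stray values before training. This is only a sketch: the helper name `find_non_binary_labels` and the sample rows are made up for illustration, and it assumes the labels are available as a list of per-example rows.

```python
def find_non_binary_labels(label_rows):
    """Return (row_index, column_index, value) for every label that is not 0 or 1."""
    bad = []
    for row_idx, row in enumerate(label_rows):
        for col_idx, value in enumerate(row):
            # 0/1, 0.0/1.0 and False/True all pass; anything else is flagged.
            if value not in (0, 1):
                bad.append((row_idx, col_idx, value))
    return bad


# Hypothetical label matrix with one rogue value (a 2 instead of a 1).
labels = [
    [1, 0, 1],
    [0, 2, 1],
    [0, 0, 0],
]
for row_idx, col_idx, value in find_non_binary_labels(labels):
    print(f"row {row_idx}, label column {col_idx}: unexpected value {value!r}")
```

Running the check over the full dataset, not just one split, avoids the situation where the error appears only when the bad row happens to land in validation.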