Most efficient multi-label classifier?


I’m trying to train a model in TensorFlow to classify text according to a fixed set of 5 labels. For example, let’s say I feed my model the following text:

“my advice is that you go ahead with your plans to learn Python, because its syntax is easy for beginners. It’s also great for snake lovers like me!”

After reading the text, the model would ideally report how strongly the text matches each of my predefined labels:

       Label             Prediction
--------------------     ----------
programming_advice          0.99
advice_for_beginners        0.91
cooking_advice              0.11
health_advice               0.10
not_advice                  0.01
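These scores are not a probability distribution over the labels (they add up to 2.12), so each label has to be scored independently. The following minimal sketch uses made-up logits to show the difference. A softmax forces the five labels to compete for a total of 1. An element-wise sigmoid lets several labels be high at once, which is what a multi-label setup needs.

```python
import math

LABELS = ["programming_advice", "advice_for_beginners", "cooking_advice",
          "health_advice", "not_advice"]

def sigmoid(x):
    """Independent per-label probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))

def softmax(xs):
    """Probabilities that compete with each other and sum to 1."""
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical raw scores (logits) from a 5-output classifier head.
logits = [4.6, 2.3, -2.1, -2.2, -4.6]

multi_label = [sigmoid(z) for z in logits]   # one yes/no decision per label
single_label = softmax(logits)               # labels compete for one slot

for name, p_sig, p_soft in zip(LABELS, multi_label, single_label):
    print(f"{name:<22} sigmoid={p_sig:.2f} softmax={p_soft:.2f}")
```

With the softmax, advice_for_beginners is pushed below 0.1 simply because programming_advice is very likely. The sigmoid keeps it at about 0.91.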

My question

What is the most efficient way to build such a classifier? I’ve seen several options to do this, but I’m not sure which one would be best:

  1. Fine-tune five different binary classifiers, since there are five labels… but this would take forever to train, so I assume there must be a better way.
  2. Make a model with a transformer only, and train it.
  3. Make a model with a transformer plus my own Dense layers, and train it.
  4. Make a model with a transformer plus my own Dense layers—but freeze the transformer as-is, and only train the Dense layers.
    • Freezing is a common practice with pre-trained computer vision models; I don’t know whether it’s good practice for NLP as well.
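The following back-of-the-envelope arithmetic makes the cost trade-off between the four options concrete. It rests on several assumptions:

- a BERT-base-sized encoder of roughly 110 million parameters with a 768-dimensional output,
- one output unit per binary classifier in option 1,
- a hypothetical extra 256-unit Dense layer in options 3 and 4.

The exact figures depend on the model you pick.

```python
# Rough count of trainable parameters for options 1-4.
# All sizes are assumptions: a BERT-base-like encoder (~110M parameters,
# 768-dim output), 5 labels, and a hypothetical 256-unit extra Dense layer.
ENCODER_PARAMS = 110_000_000
HIDDEN = 768
NUM_LABELS = 5
EXTRA_UNITS = 256  # hypothetical size of the extra Dense layer

def dense_params(n_in, n_out):
    """Weights plus biases of one fully connected layer."""
    return n_in * n_out + n_out

# Option 1: five separate encoders, each with a single-output binary head.
option1 = NUM_LABELS * (ENCODER_PARAMS + dense_params(HIDDEN, 1))
# Option 2: one shared encoder with a single 5-output linear head.
option2 = ENCODER_PARAMS + dense_params(HIDDEN, NUM_LABELS)
# Option 3: shared encoder, extra Dense layer, then the 5-output head.
option3 = (ENCODER_PARAMS + dense_params(HIDDEN, EXTRA_UNITS)
           + dense_params(EXTRA_UNITS, NUM_LABELS))
# Option 4: same layers as option 3, but the encoder is frozen,
# so only the Dense layers receive gradient updates.
option4 = dense_params(HIDDEN, EXTRA_UNITS) + dense_params(EXTRA_UNITS, NUM_LABELS)

for name, n in [("option 1", option1), ("option 2", option2),
                ("option 3", option3), ("option 4", option4)]:
    print(f"{name}: {n:,} trainable parameters")
```

Option 2 shares one encoder across all labels, so it trains about a fifth as many weights as option 1. Freezing (option 4) shrinks the trainable set dramatically. However, every forward pass still runs through the full encoder, so training gets cheaper but not in proportion to the parameter count.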

I would be grateful for any suggestions on which of 1-4 works best. I’m still rather new around here, but the Hugging Face community is extremely welcoming and helpful, and I appreciate being here! A big thanks to anybody who can give me some pointers.


I am facing the same problem, and since no one has replied yet, I wanted to ask whether you have any updates or new thoughts on this. Cheers


Option 2 is indeed the best. To train a multi-label classifier, you can use an xxxForSequenceClassification model (which is a Transformer encoder with a linear layer on top), and set the problem_type attribute of the configuration to multi_label_classification. For example, if you want to use BERT, you can do it as follows:

from transformers import BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained("bert-base-uncased", problem_type="multi_label_classification")

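With this setting, the model computes its loss with BCE (binary cross-entropy) on the logits, via PyTorch’s BCEWithLogitsLoss, instead of the cross-entropy loss used for single-label classification.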
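Here is a plain-Python sketch of what this loss computes, for intuition only. It applies binary cross-entropy to logits and averages over every label of every example, which is the default mean reduction of BCEWithLogitsLoss. Each label is treated as its own yes/no problem, and targets may be soft values in [0, 1]. The logits and targets below are made up.

```python
import math

def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over all labels of all examples.

    Uses the numerically stable form max(z, 0) - z*y + log(1 + exp(-|z|)),
    which equals -[y*log(sigmoid(z)) + (1 - y)*log(1 - sigmoid(z))].
    """
    terms = []
    for row_z, row_y in zip(logits, targets):
        for z, y in zip(row_z, row_y):
            terms.append(max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z))))
    return sum(terms) / len(terms)

# Hypothetical batch: 2 examples x 5 labels.
logits = [[4.6, 2.3, -2.1, -2.2, -4.6],
          [0.0, 0.0, 0.0, 0.0, 0.0]]
targets = [[0.99, 0.91, 0.11, 0.10, 0.01],  # soft targets are allowed
           [1.0, 0.0, 1.0, 0.0, 1.0]]

print(round(bce_with_logits(logits, targets), 4))
```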

Note that the default is num_labels=2. If you have more than two labels (five in your case), you also need to specify num_labels=... when calling the .from_pretrained() method.
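In a multi-label setup, each example’s target is a vector of num_labels floats. It holds 1.0 for every label that applies and 0.0 otherwise, and BCE needs these targets as floats. The sketch below builds such multi-hot vectors for the five labels from the question. It also builds the id2label/label2id mappings that Transformers configs accept. The helper name `to_multi_hot` is hypothetical.

```python
LABELS = ["programming_advice", "advice_for_beginners", "cooking_advice",
          "health_advice", "not_advice"]

# Mappings in the form Transformers configs use (id2label / label2id).
id2label = {i: name for i, name in enumerate(LABELS)}
label2id = {name: i for i, name in id2label.items()}

def to_multi_hot(active_labels, label2id=label2id):
    """Turn a set of label names into a float vector of length num_labels."""
    vec = [0.0] * len(label2id)
    for name in active_labels:
        vec[label2id[name]] = 1.0  # raises KeyError for an unknown label
    return vec

# The Python-advice text from the question would carry two labels:
print(to_multi_hot({"programming_advice", "advice_for_beginners"}))
```

These mappings can be passed to .from_pretrained() together with num_labels=len(LABELS) and problem_type="multi_label_classification", e.g. id2label=id2label, label2id=label2id. The label names then appear in the saved config.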


Below is a code sample that I got working using the multi_label_classification setting suggested above:

import torch
from torch.utils.data import Dataset

from transformers import (AutoTokenizer, AutoModelForSequenceClassification,
                          Trainer, TrainingArguments)

# Example data. 
# In reality, the strings are usually longer and there are 11 possible classes
texts = [
    "This is the first sentence.",
    "This is the second sentence.",
    "This is another sentence.",
    "Finally, the last sentence.",
]

# One row of 5 soft (0-1) targets per text; BCE accepts soft targets.
labels = [
    [0.99, 0.91, 0.11, 0.10, 0.01],
    [0.89, 0.51, 0.01, 0.10, 0.01],
    [0.39, 0.21, 0.11, 0.10, 0.11],
    [0.29, 0.91, 0.51, 0.20, 0.51],
]

train_texts = texts[:2]
train_labels = labels[:2]

eval_texts = texts[2:]
eval_labels = labels[2:]

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

train_encodings = tokenizer(train_texts, padding="max_length", truncation=True, max_length=512)
eval_encodings = tokenizer(eval_texts, padding="max_length", truncation=True, max_length=512)

class TextClassifierDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item["labels"] = torch.tensor(self.labels[idx])
        return item

train_dataset = TextClassifierDataset(train_encodings, train_labels)
eval_dataset = TextClassifierDataset(eval_encodings, eval_labels)

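# The original calls were cut off here. The arguments below are an assumed
# completion based on the earlier reply: BERT with
# problem_type="multi_label_classification" and num_labels matching the
# 5 columns of `labels`. output_dir is a placeholder.
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    problem_type="multi_label_classification",
    num_labels=5,
)

training_arguments = TrainingArguments(
    output_dir="./results",  # placeholder path (assumed)
)

trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

trainer.train()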
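After training, the model returns one raw logit per label. The minimal, framework-free sketch below turns those logits into a ranked table like the one in the original question. It applies a sigmoid to each logit independently and keeps the labels whose probability clears a threshold. The `decode` helper and the 0.5 threshold are assumptions, not part of the Transformers API, and the threshold may be worth tuning on validation data.

```python
import math

LABELS = ["programming_advice", "advice_for_beginners", "cooking_advice",
          "health_advice", "not_advice"]

def decode(logits, labels=LABELS, threshold=0.5):
    """Map one example's logits to ranked probabilities and predicted labels.

    `threshold` is an assumed cut-off; each label is judged independently.
    """
    probs = {name: 1.0 / (1.0 + math.exp(-z)) for name, z in zip(labels, logits)}
    predicted = [name for name, p in probs.items() if p >= threshold]
    # Sort from most to least likely, like the table in the question.
    ranked = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)
    return ranked, predicted

# Hypothetical logits, e.g. one row of the model's output logits.
ranked, predicted = decode([4.6, 2.3, -2.1, -2.2, -4.6])
for name, p in ranked:
    print(f"{name:<22} {p:.2f}")
print("predicted:", predicted)
```

In PyTorch, torch.sigmoid(outputs.logits) performs the same sigmoid step on a whole batch at once.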
