Most efficient multi-label classifier?

Below is a code sample that I managed to get working by using the `multi_label_classification` problem type:



import torch
from torch.utils.data.dataset import Dataset

from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import Trainer, TrainingArguments

# Toy corpus for demonstration only.
# Real inputs are usually longer, and the real task has 11 possible classes.
texts = [
    "This is the first sentence.",
    "This is the second sentence.",
    "This is another sentence.",
    "Finally, the last sentence.",
]

# One row of per-class scores for each entry of `texts` (five classes here).
labels = [
    [0.99, 0.91, 0.11, 0.10, 0.01],
    [0.89, 0.51, 0.01, 0.10, 0.01],
    [0.39, 0.21, 0.11, 0.10, 0.11],
    [0.29, 0.91, 0.51, 0.20, 0.51],
]

# The first two examples are used for training, the remaining ones for evaluation.
train_texts, eval_texts = texts[:2], texts[2:]
train_labels, eval_labels = labels[:2], labels[2:]

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)

# Both splits are tokenized with identical settings: truncate at 512 tokens
# and pad every example up to that same maximum length.
train_encodings = tokenizer(
    train_texts,
    truncation=True,
    padding="max_length",
    max_length=512,
)
eval_encodings = tokenizer(
    eval_texts,
    truncation=True,
    padding="max_length",
    max_length=512,
)


class TextClassifierDataset(Dataset):
    """Dataset pairing tokenizer encodings with per-example label rows.

    `encodings` is a mapping from field name to a per-example sequence of
    values; `labels` is a sequence with one entry per example.
    """

    def __init__(self, encodings, labels):
        """Store the encodings mapping and the label sequence as given."""
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        """Return the number of examples, taken from the label sequence."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return example `idx` as a dict of tensors.

        The dict holds one tensor per encoding field plus a "labels" tensor
        built from the label entry at the same index.
        """
        sample = {}
        for field_name, column in self.encodings.items():
            sample[field_name] = torch.tensor(column[idx])
        sample["labels"] = torch.tensor(self.labels[idx])
        return sample

# Wrap the tokenized splits so the Trainer can index individual examples.
train_dataset = TextClassifierDataset(train_encodings, train_labels)
eval_dataset = TextClassifierDataset(eval_encodings, eval_labels)

# `num_labels` must equal the length of each label row (5 in the example data).
# With problem_type="multi_label_classification" the sequence-classification
# head scores each class independently (per the Transformers docs it trains
# with BCEWithLogitsLoss), so the labels must be floats.
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    problem_type="multi_label_classification",
    num_labels=5,
)

# `Trainer` and `TrainingArguments` come from `transformers`; they must be
# imported at the top of the file or these calls raise NameError.
# NOTE(review): newer transformers releases name `evaluation_strategy`
# `eval_strategy` instead -- confirm against the installed version.
training_arguments = TrainingArguments(
    output_dir=".",
    evaluation_strategy="epoch",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=1,
)

trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Runs one training epoch, evaluating on `eval_dataset` at the end of it.
trainer.train()
2 Likes