The Best Approach for Weighted Multilabel Classification

Hello.

I have a task in which each record has 6 different labels, and every label can take a value from 0 to 3. The dataset is highly imbalanced.

text  label_1  label_2  label_3  label_4  label_5  label_6
…     0        1        0        2        0        0
…     0        0        0        0        0        0
…     2        0        0        0        0        3

I want to solve this task using transformers. Should I set num_labels to 24 when initializing the model?

from transformers import AutoModelForSequenceClassification

num_labels = 6  # Number of labels
classes_per_label = 4  # Number of intensity levels per label (0, 1, 2, 3)
total_classes = num_labels * classes_per_label

model = AutoModelForSequenceClassification.from_pretrained(model_name,
                                                           problem_type="multi_label_classification",
                                                           ignore_mismatched_sizes=True,
                                                           num_labels=total_classes)

In addition, what are the best practices for (1) creating a Dataset object with torch.utils.data.Dataset, (2) defining a loss function, and (3) defining thresholds when predicting and evaluating the labels?

Here is my current code:

import torch

def encode_data(df, tokenizer, label_columns):
    encodings = tokenizer(list(df['text']), padding=True, truncation=True, max_length=128)
    labels = df[label_columns].values
    return encodings, labels

class WeightedMultiLabelDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels'] = self.labels[idx]
        return item

# Prepare datasets
train_encodings, train_labels = encode_data(train_df, tokenizer, label_columns)
dev_encodings, dev_labels = encode_data(dev_df, tokenizer, label_columns)

train_dataset = WeightedMultiLabelDataset(train_encodings, train_labels)
dev_dataset = WeightedMultiLabelDataset(dev_encodings, dev_labels)
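import numpy as np
import torch
from sklearn.metrics import classification_report, average_precision_score
from sklearn.preprocessing import label_binarize

def compute_metrics(pred):
    logits, labels = pred  # NumPy arrays from Trainer

    # Group the 24 logits into 6 heads of 4 intensity levels: (N, 6, 4)
    logits = torch.tensor(logits).view(-1, num_labels, classes_per_label)
    probabilities = torch.softmax(logits, dim=-1).numpy()  # (N, 6, 4)
    predictions = logits.argmax(dim=-1).numpy()  # (N, 6)
    labels = np.asarray(labels).reshape(-1, num_labels)  # already NumPy, so no .numpy()

    levels = list(range(classes_per_label))
    auprc_per_label = []
    for i in range(num_labels):
        # average_precision_score needs binary targets: one-vs-rest over the 4 levels
        y_true = label_binarize(labels[:, i], classes=levels)
        auprc = average_precision_score(y_true, probabilities[:, i, :], average='macro')
        auprc_per_label.append(auprc)

        # classification_report does not accept multiclass-multioutput input, so report per label
        print(label_columns[i])
        print(classification_report(labels[:, i], predictions[:, i], labels=levels, zero_division=0))

    mean_auprc = float(np.mean(auprc_per_label))

    return {
        'mean_auprc': mean_auprc,
        # Trainer logs scalars, so return one entry per label instead of a list
        **{f'auprc_{name}': float(a) for name, a in zip(label_columns, auprc_per_label)},
    }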

Thank you!


Hi there, I read your question and can see you’re working on an interesting multi-label classification task. Let me help clarify your doubts and provide some guidance on best practices.

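First, regarding num_labels: setting it to 24 (6 labels × 4 intensity levels) is a reasonable choice, but problem_type="multi_label_classification" is not. Each of your 6 labels is an independent 4-way classification (values 0, 1, 2, 3), so the model needs 4 logits per label, i.e. 24 outputs, with a softmax over each group of 4. With problem_type="multi_label_classification", Transformers instead applies BCEWithLogitsLoss to all 24 logits as independent yes/no outputs and expects float targets of shape (batch, 24), which does not match your (batch, 6) integer labels. Keep num_labels=24 and compute the loss yourself (for example by overriding Trainer.compute_loss): reshape the logits to (batch, 6, 4) and apply cross-entropy per label.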
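A per-label softmax needs 4 logits per label, i.e. 24 outputs in total. Here is a minimal sketch with random tensors standing in for the model's logits (the batch size of 2 is arbitrary); it assumes logits 0 to 3 belong to label_1, 4 to 7 to label_2, and so on, which holds as long as you use the same reshape in training and evaluation.

```python
import torch
import torch.nn.functional as F

num_labels = 6         # label_1 ... label_6
classes_per_label = 4  # intensity levels 0, 1, 2, 3

# Stand-in for model(**batch).logits with num_labels=24: shape (batch, 24)
logits = torch.randn(2, num_labels * classes_per_label)
# Integer targets, one intensity level per label: shape (batch, 6), dtype long
labels = torch.tensor([[0, 1, 0, 2, 0, 0],
                       [2, 0, 0, 0, 0, 3]])

# Group the 24 logits into 6 independent 4-way classifiers: (batch, 6, 4)
grouped = logits.view(-1, num_labels, classes_per_label)

# cross_entropy wants (samples, classes) vs (samples,), so flatten the label axis into the batch
loss = F.cross_entropy(grouped.reshape(-1, classes_per_label), labels.reshape(-1))

probs = grouped.softmax(dim=-1)  # (batch, 6, 4); each group of 4 sums to 1
preds = grouped.argmax(dim=-1)   # (batch, 6); predicted intensity per label
```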

For the rest of your queries, here are my suggestions:

1. Creating a Dataset Object

Your current implementation of WeightedMultiLabelDataset is good. Since each label is an integer class index (0 to 3), keep dtype=torch.long: torch.nn.CrossEntropyLoss expects integer class indices as targets. Switch to torch.float only if you change the target format to one-hot vectors or probabilities (for example, for BCE-style losses).

Also, verify that your tokenizer output includes all the fields the model expects, such as input_ids and attention_mask, plus token_type_ids for models that use them.

2. Defining the Loss Function

For this task, you can use torch.nn.CrossEntropyLoss for each label, since each label is categorical with four classes. Because your dataset is imbalanced, consider passing class weights to counter the imbalance. Here's an example:

loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)  

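You can calculate class_weights from the frequency of each class in your dataset. Since each label has its own class distribution, it is usually better to compute a separate weight vector (and loss term) for each label.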
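A minimal sketch of per-label inverse-frequency weights and a weighted per-label cross-entropy, assuming labels of shape (N, 6) and logits grouped 4 per label. The helper names per_label_class_weights and weighted_multi_head_loss are hypothetical; the weight formula mirrors scikit-learn's "balanced" scheme (n_samples / (n_classes * count)), with counts clamped to 1 so levels that never occur in a column do not cause a division by zero.

```python
import numpy as np
import torch
import torch.nn.functional as F

def per_label_class_weights(labels, classes_per_label=4):
    """Inverse-frequency weights, one weight vector per label column.

    labels: integer array (N, num_labels) with values in [0, classes_per_label).
    Returns a float tensor of shape (num_labels, classes_per_label).
    """
    labels = np.asarray(labels)
    weights = []
    for col in labels.T:
        counts = np.bincount(col, minlength=classes_per_label).astype(np.float64)
        counts = np.maximum(counts, 1.0)  # guard against levels absent from this column
        weights.append(len(col) / (classes_per_label * counts))
    return torch.tensor(np.stack(weights), dtype=torch.float)

def weighted_multi_head_loss(logits, labels, class_weights):
    """Weighted cross-entropy computed separately for each label, averaged over labels.

    logits: (batch, num_labels * classes_per_label); labels: (batch, num_labels), long;
    class_weights: (num_labels, classes_per_label).
    """
    num_labels, classes_per_label = class_weights.shape
    grouped = logits.view(-1, num_labels, classes_per_label)
    losses = [
        F.cross_entropy(grouped[:, i, :], labels[:, i],
                        weight=class_weights[i].to(logits.device))
        for i in range(num_labels)
    ]
    return torch.stack(losses).mean()

# Usage with the three example rows from the question
train_labels = np.array([[0, 1, 0, 2, 0, 0],
                         [0, 0, 0, 0, 0, 0],
                         [2, 0, 0, 0, 0, 3]])
w = per_label_class_weights(train_labels)
loss = weighted_multi_head_loss(torch.randn(3, 24), torch.tensor(train_labels), w)
```

To use this with Trainer, one option is a Trainer subclass that overrides compute_loss, pops labels from inputs, runs the model, and returns this loss. The compute_loss signature has changed across transformers releases, so accepting **kwargs is a safe way to stay compatible.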

3. Defining Thresholds for Prediction and Evaluation

During prediction, use torch.softmax to get the probability of each intensity level and torch.argmax to pick the most probable level for each label. No additional thresholds are needed, because each label is a single-choice decision among four levels rather than a set of independent yes/no decisions.

Here’s how you can adjust your code:

logits = torch.tensor(logits).view(-1, num_labels, classes_per_label)  # (N, 6, 4)
probabilities = torch.softmax(logits, dim=-1).numpy()  # (N, 6, 4): keep all 4 level probabilities per label
predictions = logits.argmax(dim=-1).numpy()  # (N, 6): predicted level per label
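With predictions shaped (N, 6), metrics are easiest to compute one label at a time. A minimal sketch of a helper (the name per_label_metrics is hypothetical) that returns only scalar floats, which Trainer can log; it assumes logits arrive as a NumPy array of shape (N, 24) grouped 4 per label, and computes AUPRC one-vs-rest over the 4 levels because average_precision_score needs binary targets.

```python
import numpy as np
from sklearn.metrics import accuracy_score, average_precision_score, f1_score, matthews_corrcoef
from sklearn.preprocessing import label_binarize

def per_label_metrics(logits, labels, label_columns, classes_per_label=4):
    """Flat dict of scalar metrics: one set per label column plus means."""
    num_labels = len(label_columns)
    logits = np.asarray(logits).reshape(-1, num_labels, classes_per_label)
    # Numerically stable softmax over the intensity dimension
    exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs = exp / exp.sum(axis=-1, keepdims=True)        # (N, 6, 4)
    preds = probs.argmax(axis=-1)                        # (N, 6)
    labels = np.asarray(labels).reshape(-1, num_labels)  # (N, 6)

    classes = list(range(classes_per_label))
    metrics = {}
    for i, name in enumerate(label_columns):
        y_true, y_pred = labels[:, i], preds[:, i]
        metrics[f"{name}_macro_f1"] = f1_score(y_true, y_pred, labels=classes,
                                               average="macro", zero_division=0)
        metrics[f"{name}_accuracy"] = accuracy_score(y_true, y_pred)
        metrics[f"{name}_mcc"] = matthews_corrcoef(y_true, y_pred)
        # One-vs-rest binarization of the 4 levels, macro-averaged AP
        y_bin = label_binarize(y_true, classes=classes)
        metrics[f"{name}_auprc"] = average_precision_score(y_bin, probs[:, i, :], average="macro")
    for key in ("macro_f1", "mcc", "auprc"):
        metrics[f"mean_{key}"] = np.mean([metrics[f"{n}_{key}"] for n in label_columns])
    return {k: float(v) for k, v in metrics.items()}
```

Inside compute_metrics you could then write `logits, labels = pred` followed by `return per_label_metrics(logits, labels, label_columns)`. Levels that never occur in a label's evaluation split make the per-class AUPRC undefined (scikit-learn warns), which is common with heavy imbalance, so the macro-F1 and MCC columns are often the more stable signals.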

Additional Suggestions

  1. Handle Imbalance: Use WeightedRandomSampler during training to address class imbalance.
  2. Evaluation Metrics: In addition to AUPRC, consider metrics like F1-score, accuracy, and Matthews correlation coefficient for a more comprehensive evaluation.
  3. Batch Processing: Ensure that you are batching your data correctly and using the appropriate device (e.g., GPU) for faster training.
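WeightedRandomSampler needs a single weight per record, but each record here has 6 labels. One heuristic, which is an assumption rather than a standard rule, is to weight each record by the inverse frequency of the rarest (label, level) value it contains. The helper name sample_weights_from_rarest_class is hypothetical, and TensorDataset stands in for your real dataset.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

def sample_weights_from_rarest_class(labels, classes_per_label=4):
    """One weight per record: inverse frequency of the rarest level it has for any label."""
    labels = np.asarray(labels)
    n, num_labels = labels.shape
    weights = np.zeros(n)
    for i in range(num_labels):
        counts = np.bincount(labels[:, i], minlength=classes_per_label)
        inv = 1.0 / counts[labels[:, i]]  # every observed level has count >= 1
        weights = np.maximum(weights, inv)
    return torch.tensor(weights, dtype=torch.double)

train_labels = np.array([[0, 1, 0, 2, 0, 0],
                         [0, 0, 0, 0, 0, 0],
                         [2, 0, 0, 0, 0, 3],
                         [0, 0, 0, 0, 0, 0]])
weights = sample_weights_from_rarest_class(train_labels)

# Draw with replacement so rare records can appear several times per epoch
sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
loader = DataLoader(TensorDataset(torch.tensor(train_labels)), batch_size=2, sampler=sampler)
```

With the Hugging Face Trainer, a custom sampler requires subclassing and overriding internal methods, so this is simplest in a plain PyTorch loop. Combining oversampling with a class-weighted loss can over-correct, so it is usually worth trying each on its own first.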

Example Adjustments

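Here's a slightly modified version of your dataset class. It keeps the labels as integer class indices for CrossEntropyLoss and converts the (padded) tokenizer output to tensors once, up front:

import torch

class WeightedMultiLabelDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, labels):
        # Convert input_ids, attention_mask, etc. to tensors once; requires padding=True so rows are equal length
        self.encodings = {key: torch.tensor(val) for key, val in encodings.items()}
        # Integer class indices (0 to 3), shape (N, 6): the target format CrossEntropyLoss expects
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: val[idx] for key, val in self.encodings.items()}
        item['labels'] = self.labels[idx]
        return item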

Your approach is solid! By following these adjustments, you should be able to handle the multi-label classification task effectively. Let me know if you need further clarification or assistance. Good luck!
