Change the classification threshold

Hello everyone,
I have a binary text classification task that I solve with the help of bert-base-multilingual-uncased and the HF BertForSequenceClassification. As I have very unbalanced costs for false positives and false negatives, I was wondering whether it is possible to change the classification threshold.
In other words, is it possible to adjust the classification head such that:
prediction(x) = 1 if p(x) > 0.7 and prediction(x) = 0 if p(x) <= 0.7
instead of the usual
prediction(x) = 1 if p(x) > 0.5 and prediction(x) = 0 if p(x) <= 0.5,
where p(x) is the predicted probability of the positive class?

Thanks in advance for any solutions, tips or hints!


Hello!

You can easily modify the classification threshold for your binary classification task using the output logits from the BertForSequenceClassification model. Here’s how you can do it step by step:

  1. Get the logits from the model output:
    When you pass your input to the BertForSequenceClassification model, the output contains logits (unnormalized scores) that still need to be converted to probabilities.

  2. Apply the sigmoid function for binary classification:
    Since this is a binary classification task with a single output logit (num_labels=1), the logit can be passed through a sigmoid activation to map it to a probability in the range [0, 1].

  3. Adjust the threshold:
    After calculating the probabilities, you can apply a custom threshold to determine the predictions. For example, use 0.7 instead of the default 0.5.

Here’s a Python example of how to implement this:

import torch
from transformers import BertTokenizer, BertForSequenceClassification

# Load model and tokenizer (use your fine-tuned checkpoint in practice;
# num_labels=1 gives a single-logit head suited for sigmoid)
model = BertForSequenceClassification.from_pretrained("bert-base-multilingual-uncased", num_labels=1)
tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-uncased")
model.eval()

# Example text input
text = "Your input text here"
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)

# Get model outputs without tracking gradients
with torch.no_grad():
    logits = model(**inputs).logits

# Apply sigmoid to map the single logit to a probability in [0, 1]
probability = torch.sigmoid(logits).squeeze().item()

# Set a custom threshold
threshold = 0.7
prediction = 1 if probability > threshold else 0

print(f"Probability: {probability:.4f}")
print(f"Prediction: {prediction}")

Explanation

  • The logits from the model are converted to probabilities using the sigmoid function, as this is a binary classification task.
  • By setting threshold = 0.7, predictions are classified as 1 (positive) if the probability exceeds 0.7 and as 0 (negative) otherwise.
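
Note that the snippet above assumes a single-logit head (num_labels=1). If your fine-tuned model uses the more common two-label head (num_labels=2), apply softmax instead and threshold the probability of the positive class. A minimal sketch, assuming index 1 is the positive label (the dummy logits stand in for model(**inputs).logits):

import torch

# Dummy logits standing in for model(**inputs).logits with num_labels=2
logits = torch.tensor([[0.3, 1.2]])

probs = torch.softmax(logits, dim=-1)     # softmax over the two classes
p_positive = probs[:, 1]                  # probability of the positive class (index 1 assumed)
prediction = (p_positive > 0.7).long()    # 1 if above the custom threshold, else 0
print(prediction.item())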

Custom Loss Adjustment (Optional)

If your false positive and false negative costs are imbalanced, you might also consider using a weighted loss function (like BCEWithLogitsLoss with a pos_weight) during training so the model itself accounts for the imbalance. Here’s an example:

import torch
from torch.nn import BCEWithLogitsLoss

# Define class weights: pos_weight > 1 counts positive examples more heavily
# (a common starting point is num_negative_examples / num_positive_examples)
pos_weight = torch.tensor([<your_positive_weight>])
criterion = BCEWithLogitsLoss(pos_weight=pos_weight)

# During training: labels must be floats with the same shape as logits
loss = criterion(logits, labels.float())
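
If you train with the HF Trainer, one way to plug this in is to subclass it and override compute_loss. A minimal sketch, assuming num_labels=1 and float labels; the WeightedLossTrainer name and the pos_weight value of 2.0 are placeholders:

import torch
from torch.nn import BCEWithLogitsLoss
from transformers import Trainer

class WeightedLossTrainer(Trainer):
    # Hypothetical subclass: swaps the default loss for a weighted BCE
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        criterion = BCEWithLogitsLoss(
            pos_weight=torch.tensor([2.0], device=logits.device)  # placeholder weight
        )
        loss = criterion(logits.view(-1), labels.float().view(-1))
        return (loss, outputs) if return_outputs else loss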

Hope this helps! :blush:


Hello Alan,

many thanks for your reply, this helped a lot. As I want to score the model later via an API, I solved the problem a bit differently: I created a classification pipeline that inherits from TextClassificationPipeline and saved it.
Maybe the following is interesting for someone who struggles at the same point I did :slight_smile:

from transformers import AutoModelForSequenceClassification, TextClassificationPipeline

class ThresholdTextClassificationPipeline(TextClassificationPipeline):
    def __init__(self, model, tokenizer, threshold=0.7, **kwargs):
        super().__init__(model=model, tokenizer=tokenizer, **kwargs)
        self.threshold = threshold

    def __call__(self, *args, **kwargs):
        outputs = super().__call__(*args, **kwargs)
        for output in outputs:
            # The pipeline reports the score of its top label; recover the
            # probability of LABEL_1 before thresholding (two-label softmax assumed)
            p_1 = output['score'] if output['label'] == 'LABEL_1' else 1.0 - output['score']
            output['label'] = 'LABEL_1' if p_1 > self.threshold else 'LABEL_0'
            output['score'] = p_1
        return outputs

threshold_pipe = ThresholdTextClassificationPipeline(
    model=AutoModelForSequenceClassification.from_pretrained(model_output_dir),
    tokenizer=tokenizer,
    threshold=threshold,
    framework="pt",
    batch_size=batch_size,
)
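
For illustration, calling the pipeline then looks like this (the texts and scores below are made up):

results = threshold_pipe(["first example text", "second example text"])
print(results)
# e.g. [{'label': 'LABEL_0', 'score': 0.41}, {'label': 'LABEL_1', 'score': 0.88}]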