I am working on an imbalanced binary classification problem in which the minority class is the important one. To reduce overfitting towards the majority class, I am experimenting with cost-sensitive training. Following the transformers documentation, I override the Trainer's loss computation with a custom loss function and pass it class weights derived from the class distribution. This is my custom compute_loss function and training code:

    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from sklearn.utils import class_weight
    from transformers import Trainer

    # `device`, `tokenized_datasets`, `model`, `tokenizer` and `data_collator` are defined earlier (not shown).
    y_train = tokenized_datasets['train']['label']
    # 'balanced' weights are inversely proportional to the class frequencies
    class_weights = torch.tensor(
        class_weight.compute_class_weight(class_weight='balanced', classes=np.unique(y_train), y=y_train),
        device=device,
    ).float()

    class CustomTrainer(Trainer):
        """
        Class that overrides default compute_loss to enable cost-sensitive training.
        """
        def compute_loss(self, model, inputs, return_outputs=False):
            n_labels = self.model.config.num_labels
            labels = inputs.pop("labels")
            outputs = model(**inputs)
            logits = outputs.logits
            # Weighted cross-entropy: errors on the minority class cost more
            loss_fct = nn.CrossEntropyLoss(weight=class_weights)
            loss = loss_fct(logits.view(-1, n_labels), labels.view(-1))
            return (loss, outputs) if return_outputs else loss

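
For reference, here is a minimal sketch of what the `'balanced'` weighting and the weighted cross-entropy compute, written in plain Python on made-up toy labels. The helper names `balanced_class_weights` and `weighted_cross_entropy` are hypothetical. The sketch assumes the formula `n_samples / (n_classes * count_c)` for the weights and the weighted-mean reduction `sum(w_y * loss) / sum(w_y)`, which is how `nn.CrossEntropyLoss` reduces when a `weight` is given.

```python
import math
from collections import Counter

def balanced_class_weights(labels):
    """Weight for class c = n_samples / (n_classes * count_c)."""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {c: n_samples / (n_classes * counts[c]) for c in sorted(counts)}

def weighted_cross_entropy(logits, labels, weights):
    """Weighted mean of per-example cross-entropy: sum(w_y * loss) / sum(w_y)."""
    total, norm = 0.0, 0.0
    for row, y in zip(logits, labels):
        m = max(row)  # subtract the max for a stable log-sum-exp
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        total += weights[y] * (log_z - row[y])
        norm += weights[y]
    return total / norm

# Made-up toy data: 8 majority (0) and 2 minority (1) examples.
toy_labels = [0] * 8 + [1] * 2
toy_weights = balanced_class_weights(toy_labels)
print(toy_weights)  # {0: 0.625, 1: 2.5}

# A model that outputs identical logits for both classes scores ln(2),
# whatever the class weights are.
toy_loss = weighted_cross_entropy([[0.0, 0.0]] * 10, toy_labels, toy_weights)
print(abs(toy_loss - math.log(2)) < 1e-9)  # True
```

The second print is a useful reference point: a weighted loss that stays near ln(2) ≈ 0.693 during training is what an uninformative two-class predictor would give.
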

    from transformers import TrainingArguments, EarlyStoppingCallback

    model_output_path = "../../trained_models/cost_sensitve_training/loss_weights_class_distribution/"

    training_args = TrainingArguments(
        model_output_path,
        num_train_epochs=10,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        metric_for_best_model='eval_loss',
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        load_best_model_at_end=True,
        fp16=True,
    )

    early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets['train'],
        eval_dataset=tokenized_datasets['dev'],
        callbacks=[early_stopping],
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    trainer.train()
    trainer.save_model(output_dir=model_output_path)
    trainer.save_state()

However, during inference, the logits and the probabilities (softmax of the logits) for each class have a very small variance across the entire test dataset. Below are the distribution of the class 1 output and the PR curve.
Can anyone help me debug the reason for this? I found that someone else got similar results when training with a custom trainer.

Statistic | Class 1 |
---|---|
Count | 1025 |
Mean | 0.391846 |
Max | 0.392090 |
Min | 0.391602 |
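
A minimal sketch of how such a summary can be computed from raw two-class logits, in plain Python. The helper names (`class1_probabilities`, `summarize`, `looks_collapsed`) and the `tol` threshold are hypothetical, and the logits here are made up. With the Trainer, the real logits would presumably come from `trainer.predict(...).predictions` on the test split.

```python
import math
import statistics

def class1_probabilities(logits):
    """Softmax probability of class 1 for each [logit_0, logit_1] pair."""
    probs = []
    for z0, z1 in logits:
        m = max(z0, z1)  # subtract the max for numerical stability
        e0, e1 = math.exp(z0 - m), math.exp(z1 - m)
        probs.append(e1 / (e0 + e1))
    return probs

def summarize(probs):
    """Count, mean, max, min and population standard deviation."""
    return {
        "count": len(probs),
        "mean": statistics.fmean(probs),
        "max": max(probs),
        "min": min(probs),
        "std": statistics.pstdev(probs),
    }

def looks_collapsed(summary, tol=1e-3):
    """Hypothetical check: True when max - min is below `tol`."""
    return summary["max"] - summary["min"] < tol

# Made-up logits: one set is identical for every input, the other varies.
constant_logits = [[0.22, -0.22]] * 5
varied_logits = [[2.0, -1.0], [-1.5, 1.0], [0.3, 0.1], [-3.0, 2.5]]

print(looks_collapsed(summarize(class1_probabilities(constant_logits))))  # True
print(looks_collapsed(summarize(class1_probabilities(varied_logits))))    # False
```

Running the same summary on the training and dev splits as well would show whether the near-constant output is specific to the test set or already present after training.
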