Hi,

I want to do multi label text classification.

My dataset has around 40,000 rows and around 1500 binary labels (0/1).

The number of 1s for any label is very low compared to the total number of rows.

I want to use `BertForSequenceClassification` with `problem_type="multi_label_classification"`.

I have extended the `Trainer` class to pass class weights to the `pos_weight` parameter of `BCEWithLogitsLoss`.
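To make clear what `pos_weight` actually changes, here is a plain-Python sketch of the per-element loss as documented for PyTorch's `BCEWithLogitsLoss`: for label c, the positive term is multiplied by `pos_weight[c]`, the negative term is left unweighted, and the default `reduction="mean"` averages over every (row, label) entry. The function names are hypothetical and this is only for understanding; in training you would still use the PyTorch module.

```python
import math
from typing import Sequence


def softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def weighted_bce_with_logits(
    logits: Sequence[Sequence[float]],
    targets: Sequence[Sequence[float]],
    pos_weight: Sequence[float],
) -> float:
    """Mean over all entries of -[w_c * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]."""
    total, count = 0.0, 0
    for row_x, row_y in zip(logits, targets):
        for x, y, w in zip(row_x, row_y, pos_weight):
            # log(sigmoid(x)) = -softplus(-x) and log(1 - sigmoid(x)) = -softplus(x)
            total += w * y * softplus(-x) + (1.0 - y) * softplus(x)
            count += 1
    return total / count
```

With a single row of zero logits and targets `[1, 0]`, raising `pos_weight` from 1 to 3 triples only the positive entry's contribution, so the mean loss goes from log 2 to 2·log 2.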

What can I try to improve this? My F1 score is quite low.
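Since the problem is stated in terms of F1, it helps to be explicit about how F1 is computed for multi-label outputs: each logit goes through a sigmoid, each label is thresholded independently, and the per-label counts are then either pooled over all labels (micro F1) or turned into one F1 per label and averaged (macro F1). With ~1500 sparse labels the two can differ a lot, so tracking both is informative. Below is a numpy sketch; the function name is hypothetical and the 0.5 threshold is an assumption, not a requirement. It could be wrapped into a `Trainer` `compute_metrics` function, which receives an `EvalPrediction` with `predictions` (the logits) and `label_ids`.

```python
import numpy as np


def multilabel_f1(logits: np.ndarray, labels: np.ndarray, threshold: float = 0.5) -> dict:
    """Micro- and macro-averaged F1 for multi-label predictions given raw logits."""
    probs = 1.0 / (1.0 + np.exp(-logits))      # sigmoid: one independent probability per label
    preds = (probs >= threshold).astype(int)
    labels = labels.astype(int)

    tp = (preds & labels).sum(axis=0)          # per-label true positives
    fp = (preds & (1 - labels)).sum(axis=0)    # per-label false positives
    fn = ((1 - preds) & labels).sum(axis=0)    # per-label false negatives

    # Micro F1 pools counts over all labels, so frequent labels dominate.
    micro_den = 2 * tp.sum() + fp.sum() + fn.sum()
    micro = 2 * tp.sum() / micro_den if micro_den else 0.0

    # Macro F1 averages per-label F1, so each rare label counts as much as a frequent one.
    # Labels with no positives and no predicted positives are scored as 0.
    per_label_den = 2 * tp + fp + fn
    per_label = np.divide(
        2 * tp, per_label_den,
        out=np.zeros(len(tp), dtype=float), where=per_label_den > 0,
    )
    return {"f1_micro": float(micro), "f1_macro": float(per_label.mean())}
```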

Solution: experiment with different weights. For me, the following formula seemed to improve the results:

```
from math import sqrt

from torch import FloatTensor

# Assumes the first column of df is the text and every other column is a 0/1 label.
total_count = len(df)
temperature = 0.16
weights = []
for column in df.columns[1:]:
    pos_count = df[column].sum()
    neg_count = total_count - pos_count
    pos_term = 1 / sqrt(pos_count * temperature)
    neg_term = 1 / sqrt(neg_count * temperature)
    weights.append(neg_term / pos_term)
weights = FloatTensor(weights) * 70
```
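One property of this formula worth knowing: `temperature` cancels out. Each weight is (1/√(n·T)) / (1/√(p·T)) = √(p/n), where p and n are the label's positive and negative counts, so the loop reduces to 70·√(pos/neg) and changing 0.16 has no effect on the result. It also means labels with more positives get a larger `pos_weight`, which runs in the opposite direction to the neg/pos ratio suggested in the PyTorch `BCEWithLogitsLoss` documentation; since this formula was found empirically, comparing both may be worthwhile. A vectorized numpy sketch, assuming a 0/1 label matrix (function name hypothetical):

```python
import numpy as np


def sqrt_ratio_pos_weight(labels: np.ndarray, scale: float = 70.0) -> np.ndarray:
    """Closed form of the loop above: pos_weight[j] = scale * sqrt(pos_j / neg_j).

    labels: (n_rows, n_labels) array of 0/1 values.
    Columns with no positives get weight 0 here (the loop raises ZeroDivisionError),
    and columns with no negatives divide by zero, so filter such columns out first.
    """
    pos = labels.sum(axis=0).astype(float)
    neg = labels.shape[0] - pos
    return scale * np.sqrt(pos / neg)
```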

Also note that the F1, precision, recall, and accuracy scores are likely to be low in the first epochs because the labels are so sparse, so training for more epochs (say, 50) can give better results. Keep an eye on the validation loss, though, and stop training once it starts to increase.
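The "stop once validation loss starts increasing" rule is usually implemented with a patience counter. In `transformers` this is what `EarlyStoppingCallback(early_stopping_patience=...)` does, used together with `load_best_model_at_end=True` and `metric_for_best_model="eval_loss"` in `TrainingArguments`. The plain-Python sketch below only illustrates the rule itself; the function name and defaults are hypothetical.

```python
from typing import Optional, Sequence


def early_stop_epoch(
    val_losses: Sequence[float], patience: int = 3, min_delta: float = 0.0
) -> Optional[int]:
    """Return the 0-based epoch at which training would stop, or None if it never stops.

    Training stops once the validation loss has failed to beat the best value seen
    so far by more than min_delta for `patience` consecutive epochs.
    """
    best = float("inf")
    bad_epochs = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss          # new best: reset the counter
            bad_epochs = 0
        else:
            bad_epochs += 1      # no improvement this epoch
            if bad_epochs >= patience:
                return epoch
    return None
```

A patience of a few epochs tolerates small fluctuations in validation loss instead of stopping at the first uptick.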

Here’s the weighted trainer implementation I am using:

```
import logging
from typing import Optional

from torch import FloatTensor
from torch.nn import BCEWithLogitsLoss
from transformers import Trainer


class WeightedTrainer(Trainer):
    def __init__(self, *args, class_weights: Optional[FloatTensor] = None, **kwargs):
        super().__init__(*args, **kwargs)
        if class_weights is not None:
            # Put the per-label weights on the same device as the model.
            class_weights = class_weights.to(self.args.device)
            logging.info("Using multi-label classification with class weights: %s", class_weights)
        # pos_weight=None falls back to an unweighted BCE loss.
        self.loss_fct = BCEWithLogitsLoss(pos_weight=class_weights)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """
        Weighted BCE-with-logits loss for multi-label classification.
        **kwargs absorbs extra arguments (e.g. num_items_in_batch) passed by newer Trainer versions.
        """
        # BCEWithLogitsLoss expects float targets.
        labels = inputs.pop("labels").float()
        outputs = model(**inputs)
        # Under DataParallel the actual model is wrapped in model.module.
        num_labels = model.module.num_labels if hasattr(model, "module") else model.num_labels
        loss = self.loss_fct(outputs.logits.view(-1, num_labels), labels.view(-1, num_labels))
        return (loss, outputs) if return_outputs else loss
```