Unable to train the model with weighted cross-entropy loss


I want to fine-tune a model for sentence classification. My dataset has five classes, and it’s imbalanced.
I use the Trainer class to train the model.
During the first few epochs, F1 and accuracy improve on both the training and validation sets.
After that, further epochs improve the results on the training set only, which I assume means the model is overfitting.
Since the dataset is imbalanced, I suspect the imbalance could be causing the overfitting.
I tried using CrossEntropyLoss with class weights, but then the model doesn't seem to learn anything: F1 and accuracy on both the training and validation sets keep dropping from the first epoch onward.
I calculate the class weights as follows:

from collections import Counter

stratify_column_name = "labels"
train_test = dataset.class_encode_column(stratify_column_name).train_test_split(test_size=0.1, stratify_by_column=stratify_column_name)
num_classes = len(train_test['train'].unique('labels'))

# Count how often each label id occurs in the training split (Counter.total() needs Python 3.10+)
class_frequency = Counter(train_test['train']['labels'])
class_weight = [class_frequency.total() / (class_frequency[i] * num_classes) for i in range(num_classes)]
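The scheme above gives class c the weight n / (k · n_c), where n is the number of training examples, k the number of classes, and n_c the count of class c; this is the same formula scikit-learn uses for class_weight="balanced". A minimal stdlib sketch, with a hypothetical helper name and made-up toy labels, assuming labels are integer ids 0..k-1 as produced by class_encode_column:

```python
from collections import Counter


def balanced_class_weights(labels, num_classes):
    """Weight for class c = n_samples / (num_classes * count_c).

    Assumes labels are integer ids 0..num_classes-1 and that every class
    occurs at least once; otherwise the division below raises ZeroDivisionError.
    """
    counts = Counter(labels)
    total = sum(counts.values())  # equivalent to counts.total() on Python 3.10+
    return [total / (num_classes * counts[c]) for c in range(num_classes)]


# Toy, made-up labels: class 0 is six times as frequent as class 2.
toy_labels = [0] * 6 + [1] * 3 + [2] * 1
weights = balanced_class_weights(toy_labels, num_classes=3)
print(weights)  # rare classes get larger weights
```

With these weights every class contributes the same total weight (n_c · w_c = n / k, here 10 / 3) over one pass through the data, which is what "balancing" means here. The weights only line up with the model's outputs if they are indexed by the same label ids the model is trained on.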

Then I define a custom Trainer:

import torch
from typing import List, Optional
from transformers import Trainer

class WeightedTrainer(Trainer):
    def __init__(self, *args, class_weights: Optional[List[float]] = None, **kwargs):
        super().__init__(*args, **kwargs)
        if class_weights is not None:
            # Put the weights on the same device as the model's logits
            class_weights = torch.tensor(class_weights, dtype=torch.float32).to(self.accelerator.device)
            print("Using multi-class classification with class weights:", class_weights)
        self.loss_fct = torch.nn.CrossEntropyLoss(weight=class_weights)

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs absorbs extra arguments (e.g. num_items_in_batch) passed by recent transformers versions
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        loss = self.loss_fct(outputs.logits.view(-1, self.model.config.num_labels), labels.view(-1))
        return (loss, outputs) if return_outputs else loss
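To make explicit what self.loss_fct computes, here is a dependency-free sketch of weighted cross-entropy with the "mean" reduction. It follows the behaviour documented for torch.nn.CrossEntropyLoss(weight=...): each sample's loss is scaled by the weight of its target class, and the sum is divided by the sum of those target weights rather than by the batch size. The function name and the toy logits are hypothetical, for illustration only; each row of logits plays the role of one row of outputs.logits.view(-1, num_labels).

```python
import math


def weighted_cross_entropy(logits, targets, weights):
    """Mean weighted cross-entropy over a batch.

    Mirrors torch.nn.CrossEntropyLoss(weight=weights, reduction="mean"):
    sum_i w[y_i] * nll_i / sum_i w[y_i].
    """
    num, den = 0.0, 0.0
    for row, y in zip(logits, targets):
        m = max(row)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in row))
        nll = log_sum_exp - row[y]  # -log softmax(row)[y]
        num += weights[y] * nll
        den += weights[y]
    return num / den


# Made-up batch: 3 samples, 3 classes, one sample per class.
batch_logits = [[2.0, 0.5, -1.0], [0.1, 1.5, 0.3], [0.0, 0.2, 1.0]]
batch_targets = [0, 1, 2]
class_weights = [0.5, 1.0, 3.0]

weighted = weighted_cross_entropy(batch_logits, batch_targets, class_weights)
unweighted = weighted_cross_entropy(batch_logits, batch_targets, [1.0, 1.0, 1.0])
scaled = weighted_cross_entropy(batch_logits, batch_targets, [10 * w for w in class_weights])
print(weighted, unweighted, scaled)
```

One consequence of this normalization: multiplying all class weights by the same constant leaves the loss unchanged (weighted and scaled agree), so only the ratios between the weights matter, not their absolute size.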

Am I doing something wrong? With the default Trainer, the model learns (and then overfits), but with WeightedTrainer it does not learn at all: neither accuracy nor F1 on the training set improves.

Any help is greatly appreciated.