Class weights for Segformer loss function

I am training a SegformerForSemanticSegmentation model on three classes. I trained it with the standard Trainer and got decent results. Since my dataset is unbalanced, I want to set the class weights to [0.2, 0.4, 0.4]. I searched for how to use custom class weights and found this code:

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import Trainer

loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([0.2, 0.4, 0.4], device='cuda', dtype=torch.float))

class WeightedTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        labels = inputs.get("labels").contiguous()
        # forward pass
        outputs = model(**inputs)
        logits = outputs.get("logits").contiguous()
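        # upsample the logits to the labels' spatial resolution before computing the loss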
        upsampled_logits = F.interpolate(
            logits,
            size=labels.shape[1:],
            mode='bilinear',
            align_corners=False
        )
        loss = loss_fct(upsampled_logits.view(-1, 3), labels.view(-1))
        return (loss, outputs) if return_outputs else loss

However, after changing the 'weight' argument I observed a considerable performance drop with this code. The loss stays fixed and does not decrease: with the normal trainer it usually drops from around 1 to 0.2, but with this one it stays at 1, and my mIoU dropped from around 0.85 to 0.2. Even when I kept the weights at [1.0, 1.0, 1.0], the same issue occurred, so I think the problem is not the weights themselves but that the loss SegFormer computes internally and the loss function above are different. Could someone please help me implement a weighted trainer for SegFormer? Just for reference, here is my training code:

variant = 'b4'
pretrained_model_name = f"nvidia/mit-{variant}"
dp = 0.3
model = SegformerForSemanticSegmentation.from_pretrained(
    pretrained_model_name,
    id2label=id2label,
    label2id=label2id,
    classifier_dropout_prob=dp
)
epochs = 100
lr = 1e-4
batch_size = 32
wd = 1e-2
exp_name = f"{variant}-lr={lr}-aug2more-dev-wd={wd}-dp={dp}"

training_args = TrainingArguments(
    exp_name,
    learning_rate=lr,
    num_train_epochs=epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    save_total_limit=3,
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=10,
    eval_steps=10,
    logging_steps=1,
    load_best_model_at_end=True,
    logging_dir=f"logs/{exp_name}",
    warmup_ratio=0.1,
    weight_decay=wd,
)

trainer = Trainer(
    model=model, # type: ignore
    args=training_args,
    train_dataset=train_ds, # type: ignore
    eval_dataset=valid_ds, # type: ignore
    compute_metrics=compute_metrics,
)
trainer.train()
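
For reference, this is how I understand the loss that SegformerForSemanticSegmentation computes internally when labels are passed, based on my reading of modeling_segformer.py in transformers (paraphrased from memory, so it may not match the exact version I have installed; the tensor shapes below are just placeholders for illustration):

# Sketch of the built-in SegFormer loss for num_labels > 1, as I understand it.
# Paraphrased from modeling_segformer.py; shapes are made up for illustration.
import torch
import torch.nn as nn
import torch.nn.functional as F

num_labels = 3
logits = torch.randn(2, num_labels, 128, 128)           # decode-head output (low resolution)
labels = torch.randint(0, num_labels, (2, 512, 512))     # full-resolution segmentation masks

upsampled_logits = F.interpolate(
    logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
)
# The 4D logits and 3D labels go into CrossEntropyLoss directly (no flattening),
# and ignore_index comes from config.semantic_loss_ignore_index (255 by default).
loss_fct = nn.CrossEntropyLoss(ignore_index=255)
loss = loss_fct(upsampled_logits, labels)

If that reading is right, the built-in loss differs from my loss function above in that it feeds the 4D logits straight into CrossEntropyLoss rather than a flattened view, and in that it sets ignore_index=255.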

I have the same issue