Model does not work after loss change

hello
My model finetunes bert (specifically Roberta) using a lst fully connected layer of a binary text classification task. I was using cross entropy loss and the code worked well. However when I changed the loss the model stopped learning and predicted 0 for all the examples and did not learn.
For other classification tasks the loss works fine.
The loss is decreasing but the accuracy stayes the same and the prediction is always 0.
I have tried different learning rate values and batch sizes and many other things but until now nothing worked. It happens when finetuning happens and also when the bert model is frozen. But not on other classification tasks.
The loss functio is RCE:
class ReverseCrossEntropy(torch.nn.Module):
def init(self, num_classes, scale=1.0):
super(ReverseCrossEntropy, self).init()
self.device = device
self.num_classes = num_classes
self.scale = scale

def forward(self, pred, labels):
    pred = F.softmax(pred, dim=1)
    pred = torch.clamp(pred, min=1e-7, max=1.0)
    label_one_hot = torch.nn.functional.one_hot(labels, self.num_classes).float().to(self.device)
    label_one_hot = torch.clamp(label_one_hot, min=1e-4, max=1.0)
    rce = (-1*torch.sum(pred * torch.log(label_one_hot), dim=1))
    return self.scale * rce.mean()

and I also tried NCE:
class NormalizedReverseCrossEntropy(torch.nn.Module):
def init(self, num_classes, scale=1.0):
super(NormalizedReverseCrossEntropy, self).init()
self.device = device
self.num_classes = num_classes
self.scale = scale

def forward(self, pred, labels):
    pred = F.softmax(pred, dim=1)
    pred = torch.clamp(pred, min=1e-7, max=1.0)
    label_one_hot = torch.nn.functional.one_hot(labels, self.num_classes).float().to(self.device)
    label_one_hot = torch.clamp(label_one_hot, min=1e-4, max=1.0)
    normalizor = 1 / 4 * (self.num_classes - 1)
    rce = (-1*torch.sum(pred * torch.log(label_one_hot), dim=1))
    return self.scale * normalizor * rce.mean()

They are taken from the artical [2006.13554] Normalized Loss Functions for Deep Learning with Noisy Labels
git: Active-Passive-Losses/loss.py at master · HanxunH/Active-Passive-Losses · GitHub
any help will be much appreciated.

class Model(nn.Module):

def init(self, device=‘cuda’, lm=‘roberta’, alpha_aug=0.8):
super().init()
if lm in lm_mp:
self.bert = AutoModel.from_pretrained(lm_mp[lm])
else:
self.bert = AutoModel.from_pretrained(lm)

self.device = device

# linear layer
hidden_size = self.bert.config.hidden_size
self.fc = torch.nn.Linear(hidden_size, 2)

def forward(self, x1, x2=None):
“”"Encode the left, right, and the concatenation of left+right.

Args:
    x1 (LongTensor): a batch of ID's

Returns:
    Tensor: binary prediction
"""
x1 = x1.to(self.device) # (batch_size, seq_len)
enc = self.bert(x1)[0][:, 0, :]

return self.fc(enc) 

creating the model:

device = ‘cuda’ if torch.cuda.is_available() else ‘cpu’
model = Model(device=device,
lm=hp.lm,
alpha_aug=hp.alpha_aug)
model = model.cuda()
optimizer = AdamW(model.parameters(), lr=hp.lr)
The training step is:

#deciding the loss
criterion = nn.CrossEntropyLoss()
for i, batch in enumerate(train_iter):
optimizer.zero_grad()
if len(batch) == 2:
x, y = batch
prediction = model(x)
else:
x1, x2, y = batch
prediction = model(x1, x2)

loss = criterion(prediction, y.to(model.device))
if hp.fp16:
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
else:
    loss.backward()
optimizer.step()
scheduler.step()
if i % 10 == 0: # monitoring
    print(f"step: {i}, loss: {loss.item()}")
del loss

This works well then the only change I did was for the loss:
criterion = ReverseCrossEntropy(2)
instead of cross entropy. And this change does not work.

Expected behavior
The result for training with cross entropy is:
step: 0, loss: 0.5812623500823975
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 16384.0
epoch 1: dev_f1=0.2772277227722772, f1=0.2745098039215686, best_f1=0.2745098039215686
step: 0, loss: 0.3767085075378418
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 8192.0
epoch 2: dev_f1=0.36363636363636365, f1=0.35294117647058826, best_f1=0.35294117647058826
step: 0, loss: 0.43073320388793945
epoch 3: dev_f1=0.2978723404255319, f1=0.2978723404255319, best_f1=0.35294117647058826
step: 0, loss: 0.6784828305244446
epoch 4: dev_f1=0.5365853658536585, f1=0.43999999999999995, best_f1=0.43999999999999995
step: 0, loss: 0.25015905499458313
epoch 5: dev_f1=0.43076923076923085, f1=0.4745762711864407, best_f1=0.43999999999999995
step: 0, loss: 0.329183429479599
epoch 6: dev_f1=0.8148148148148148, f1=0.7647058823529412,
step: 0, loss: 0.08995085209608078
epoch 7: dev_f1=0.88, f1=0.8333333333333333, best_f1=0.8333333333333333
step: 0, loss: 0.18586984276771545
epoch 8: dev_f1=0.9032258064516129, f1=0.8750000000000001, best_f1=0.8750000000000001
step: 0, loss: 0.007164476439356804
epoch 9: dev_f1=0.888888888888889, f1=0.8275862068965518, best_f1=0.8750000000000001
step: 0, loss: 0.005751035641878843
epoch 10: dev_f1=0.9032258064516129, f1=0.8484848484848484, best_f1=0.8750000000000001
step: 0, loss: 0.14081726968288422
epoch 11: dev_f1=0.8571428571428571, f1=0.9032258064516129, best_f1=0.8750000000000001
step: 0, loss: 0.0045958105474710464
epoch 12: dev_f1=0.896551724137931, f1=0.9032258064516129, best_f1=0.8750000000000001
step: 0, loss: 0.0023396878968924284
epoch 13: dev_f1=0.8333333333333333, f1=0.888888888888889, best_f1=0.8750000000000001
step: 0, loss: 0.0017288422677665949
epoch 14: dev_f1=0.8750000000000001, f1=0.8750000000000001, best_f1=0.8750000000000001
step: 0, loss: 0.0025747090112417936
epoch 15: dev_f1=0.896551724137931, f1=0.896551724137931, best_f1=0.8750000000000001
step: 0, loss: 0.0030487636104226112
epoch 16: dev_f1=0.88, f1=0.888888888888889, best_f1=0.8750000000000001
step: 0, loss: 0.0015720207011327147
epoch 17: dev_f1=0.896551724137931, f1=0.896551724137931, best_f1=0.8750000000000001
step: 0, loss: 0.001150735653936863
epoch 18: dev_f1=0.9333333333333333, f1=0.896551724137931, best_f1=0.896551724137931
step: 0, loss: 0.0009454995160922408
epoch 19: dev_f1=0.9333333333333333, f1=0.896551724137931, best_f1=0.896551724137931
step: 0, loss: 0.0007868938846513629
epoch 20: dev_f1=0.9333333333333333, f1=0.896551724137931, best_f1=0.896551724137931
step: 0, loss: 0.0006980099133215845
epoch 21: dev_f1=0.9333333333333333, f1=0.896551724137931, best_f1=0.896551724137931
step: 0, loss: 0.0006197747425176203
epoch 22: dev_f1=0.9333333333333333, f1=0.896551724137931,
step: 0, loss: 0.0006151695270091295
epoch 23: dev_f1=0.9333333333333333, f1=0.896551724137931, best_f1=0.896551724137931
step: 0, loss: 0.0004854918224737048
epoch 24: dev_f1=0.9333333333333333, f1=0.896551724137931, best_f1=0.896551724137931
step: 0, loss: 0.000492772669531405
epoch 25: dev_f1=0.9333333333333333, f1=0.896551724137931, best_f1=0.896551724137931
step: 0, loss: 0.0004389513051137328
epoch 26: dev_f1=0.9333333333333333, f1=0.896551724137931, best_f1=0.896551724137931
step: 0, loss: 0.0003859938296955079
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 4096.0
epoch 27: dev_f1=0.9333333333333333, f1=0.896551724137931, best_f1=0.896551724137931
step: 0, loss: 0.0004301978333387524
epoch 28: dev_f1=0.9333333333333333, f1=0.896551724137931, best_f1=0.896551724137931
step: 0, loss: 0.0004772722895722836
epoch 29: dev_f1=0.9333333333333333, f1=0.896551724137931, best_f1=0.896551724137931
step: 0, loss: 0.0003848907945211977
epoch 30: dev_f1=0.9333333333333333, f1=0.896551724137931, best_f1=0.896551724137931
step: 0, loss: 0.0003429920761846006
epoch 31: dev_f1=0.9333333333333333, f1=0.896551724137931, best_f1=0.896551724137931
step: 0, loss: 0.0004783756739925593
epoch 32: dev_f1=0.9333333333333333, f1=0.896551724137931, best_f1=0.896551724137931
step: 0, loss: 0.00039960749563761055
epoch 33: dev_f1=0.9333333333333333, f1=0.896551724137931, best_f1=0.896551724137931
step: 0, loss: 0.00043797597754746675
epoch 34: dev_f1=0.9333333333333333, f1=0.896551724137931, best_f1=0.896551724137931
step: 0, loss: 0.00025380056467838585
epoch 35: dev_f1=0.9333333333333333, f1=0.896551724137931, best_f1=0.896551724137931
step: 0, loss: 0.0003628128906711936
epoch 36: dev_f1=0.9333333333333333, f1=0.896551724137931, best_f1=0.896551724137931
step: 0, loss: 0.00036079881829209626
epoch 37: dev_f1=0.9333333333333333, f1=0.896551724137931, best_f1=0.896551724137931
step: 0, loss: 0.00036769770667888224
epoch 38: dev_f1=0.9333333333333333, f1=0.896551724137931, best_f1=0.896551724137931
step: 0, loss: 0.0003665930707938969
epoch 39: dev_f1=0.9333333333333333, f1=0.896551724137931, best_f1=0.896551724137931
step: 0, loss: 0.0002882482949644327
epoch 40: dev_f1=0.9333333333333333, f1=0.896551724137931, best_f1=0.896551724137931

The expectation was the the rusults will be similar but when changed to reverse cross entropy the results are:
step: 0, loss: 3.970363140106201
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 16384.0
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 8192.0
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 4096.0
Gradient overflow. Skipping step, loss scaler 0 reducing loss scale to 2048.0
epoch 1: dev_f1=0.28571428571428575, f1=0.30000000000000004, best_f1=0.30000000000000004
step: 0, loss: 2.027850866317749
epoch 2: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 1.72965407371521
epoch 3: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 2.015202522277832
epoch 4: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 0.5761911273002625
epoch 5: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 1.439455270767212
epoch 6: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 1.7271339893341064
epoch 7: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 0.8637082576751709
epoch 8: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 1.1514854431152344
epoch 9: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 0.863682746887207
epoch 10: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 1.7270889282226562
epoch 11: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 0.863652765750885
epoch 12: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 1.1514408588409424
epoch 13: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 1.15143883228302
epoch 14: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 2.0148658752441406
epoch 15: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 2.5904781818389893
epoch 16: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 2.0148520469665527
epoch 17: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 2.01485013961792
epoch 18: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 1.4391952753067017
epoch 19: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 1.7270371913909912
epoch 20: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 1.4392175674438477
epoch 21: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 1.4392108917236328
epoch 22: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 2.0148367881774902
epoch 23: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 2.302647113800049
epoch 24: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 0.8635783195495605
epoch 25: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 0.5757505297660828
epoch 26: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 0.5757474303245544
epoch 27: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 1.4391957521438599
epoch 28: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 2.0148279666900635
epoch 29: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 2.0148282051086426
epoch 30: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 1.4392008781433105
epoch 31: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 0.8635559678077698
epoch 32: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 0.8635714054107666
epoch 33: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 2.0148158073425293
epoch 34: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 0.8635637760162354
epoch 35: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 0.5757399201393127
epoch 36: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 0.8635669946670532
epoch 37: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 1.1513622999191284
epoch 38: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 0.8635590076446533
epoch 39: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004
step: 0, loss: 1.7269994020462036
epoch 40: dev_f1=0.2666666666666667, f1=0.2666666666666667, best_f1=0.30000000000000004

(most of the times it was 0.2666666666666667 from the first epoch)

Thank you for the help.