Cannot get transfer learning with BERT to work

Hi, I am trying to build a simple BERT-based text classifier. The problem is that my model does not train: the loss just jumps around between 0.6 and 0.8 (roughly ln(2) ≈ 0.69, i.e. chance level for a balanced two-class problem) and never goes down. The number of epochs has no impact, and changing the learning rate does not help much either. What am I doing wrong? The code is below:

from typing import Callable
import unicodedata

from transformers import BertForSequenceClassification, BertModel
from transformers import AdamW
from transformers import BertTokenizer
from transformers import get_linear_schedule_with_warmup

import torch
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, ConcatDataset

import pytorch_lightning as pl
from pytorch_lightning.metrics import functional as FM
from pytorch_lightning.loggers import TensorBoardLogger

from sklearn.metrics import (
    f1_score,
    accuracy_score,
    classification_report,
    confusion_matrix,
)

import numpy as np
import seaborn as sn
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd

class BERTDataset(Dataset):
    """Thin dataset wrapper that returns raw (text, label) pairs from a dataframe."""

    def __init__(self, source_file):
        self.data_df = source_file

    def __len__(self):
        return self.data_df.shape[0]

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        label = self.data_df.iloc[idx, 1]
        text = self.data_df.iloc[idx, 0]
        return text, label
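
Just to show what the dataset returns, here is a quick check on a made-up two-row dataframe (illustration only, not my real data):

toy_df = pd.DataFrame({"text": ["great movie", "awful movie"], "label": [1, 0]})
toy_dataset = BERTDataset(toy_df)
print(len(toy_dataset))   # 2
print(toy_dataset[0])     # ('great movie', 1)
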
def build_preprocessing_pipeline(tokenizer) -> Callable:
    """Build a collate_fn that tokenizes a batch of (text, label) pairs into model inputs."""

    def collate_fn(sample):
        sample = np.array(sample)
        encoded_batch = tokenizer.batch_encode_plus(
            sample[:, 0], padding="max_length", truncation=True
        )
        inputs = {
            "input_ids": torch.tensor(encoded_batch["input_ids"]),
            "token_type_ids": torch.tensor(encoded_batch["token_type_ids"]),
            "attention_mask": torch.tensor(encoded_batch["attention_mask"]),
            "labels": torch.from_numpy(sample[:, 1].astype(float).astype(int)),
        }
        return inputs

    return collate_fn
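
For reference, this is roughly what a collated batch looks like on two toy sentences (assuming the bert-base-cased tokenizer can be loaded; the names here are just for illustration, and the call mirrors what the DataLoader passes to collate_fn):

toy_tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
toy_collate = build_preprocessing_pipeline(toy_tokenizer)
toy_batch = toy_collate([("a good film", 1), ("a bad film", 0)])
print(toy_batch["input_ids"].shape)       # torch.Size([2, 512]) due to padding="max_length"
print(toy_batch["attention_mask"].shape)  # torch.Size([2, 512])
print(toy_batch["labels"])                # tensor([1, 0])
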
class BertClassifier(pl.LightningModule):
    """BERT encoder with a single linear classification head on top of the pooled output."""

    def __init__(
        self,
        bert_path: str,
        n_classes: int = 2,
    ) -> None:
        super().__init__()
        self.__tokenizer = BertTokenizer.from_pretrained(bert_path)
        self.__max_length = self.__tokenizer.model_max_length  # backing field for the max_length property
        self.bert = BertModel.from_pretrained(bert_path)
        self.drop = torch.nn.Dropout(p=0.3)
        self.linear1 = torch.nn.Linear(self.bert.config.hidden_size, n_classes)

        self.lr = 2e-5
        self.num_warmup_steps = 0
        self.num_training_steps = 100
        self.accuracy = pl.metrics.Accuracy()

    @property
    def max_length(self) -> int:
        return self.__max_length

    @property
    def tokenizer(self) -> BertTokenizer:
        return self.__tokenizer

    def freeze_bert(self) -> None:
        for param in self.bert.base_model.parameters():
            param.requires_grad = False

        print(
            f"Trainable parameter tensors: {len([p for p in self.bert.parameters() if p.requires_grad])}"
        )

    def unfreeze_bert(self) -> None:
        for param in self.bert.base_model.parameters():
            param.requires_grad = True

        print(
            f"Trainable parameter tensors: {len([p for p in self.bert.parameters() if p.requires_grad])}"
        )

    def configure_optimizers(
        self,
        lr: float = None,
        num_training_steps: int = None,
        num_warmup_steps: int = None,
    ):
        if lr:
            self.lr = lr
        if num_warmup_steps:
            self.num_warmup_steps = num_warmup_steps
        if num_training_steps:
            self.num_training_steps = num_training_steps

        no_decay = ["bias", "LayerNorm.weight"]

        optimizer_grouped_parameters = [
            {
                "params": [
                    p
                    for n, p in self.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.00,
            },
            {
                "params": [
                    p
                    for n, p in self.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay": 0.0,
            },
        ]
        
        optimizer = torch.optim.Adam(
            [p for p in self.parameters() if p.requires_grad], lr=self.lr, eps=1e-08
        )

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.num_warmup_steps,
            num_training_steps=self.num_training_steps - self.num_warmup_steps,
        )

        return [optimizer], [scheduler]

    def forward(self, batch):
        x = self.bert(**batch)  # x[1] is the pooled [CLS] representation
        x = self.linear1(x[1])  # raw logits; F.cross_entropy expects logits
        return x

    def training_step(self, batch, batch_idx):
        labels = batch.pop("labels")
        y_hat = self(batch)
        loss = F.cross_entropy(y_hat, labels)
        acc = FM.accuracy(y_hat, labels)
        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc_step", acc, prog_bar=True)
        return loss

    def training_epoch_end(self, outs):
        self.log("train_acc_epoch", self.accuracy.compute())

    def validation_step(self, batch, batch_idx):
        labels = batch.pop("labels")
        y_hat = self(batch)
        loss = F.cross_entropy(y_hat, labels)
        acc = FM.accuracy(y_hat, labels)
        self.log("val_loss", loss)
        return loss

    def validation_epoch_end(self, outs):
        self.log("validation_acc_epoch", self.accuracy.compute())
def get_predictions(dataset, model):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    preds = []
    labels = []
    with torch.no_grad():
        for batch in tqdm(dataset):
            batch = {key: value.to(device) for key, value in batch.items()}
            batch_labels = batch.pop("labels")  # forward() expects only the tokenizer outputs
            outputs = model(batch)              # logits of shape (batch_size, n_classes)
            _, predicted = torch.max(outputs, 1)
            preds.extend(predicted)
            labels.extend(batch_labels)
    preds = [pred.item() for pred in preds]
    labels = [label.item() for label in labels]
    return preds, labels

def preprocess_text(input_text: str) -> str:
    # drop accents and other non-ASCII characters via NFKD normalization
    input_text = (
        unicodedata.normalize("NFKD", input_text)
        .encode("ascii", "ignore")
        .decode("utf-8")
    )
    return input_text
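
This just strips accents and other non-ASCII characters, e.g.:

print(preprocess_text("a naïve café review"))  # -> "a naive cafe review"
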
train_df = pd.read_parquet("./data/imdb/train.parquet")
test_df = pd.read_parquet("./data/imdb/test.parquet")
train_df["text_preprocessed"] = train_df.text.apply(lambda x: preprocess_text(x))
test_df["text_preprocessed"] = test_df.text.apply(lambda x: preprocess_text(x))

TRAIN_SIZE = 0.85
BATCH_SIZE = 4
EPOCHS_NUM = 5
ACC_GRAD_BATCHES = 32 // BATCH_SIZE  # accumulate gradients to an effective batch size of 32
STEPS_PER_EPOCH = int(train_df.shape[0] * TRAIN_SIZE) // (BATCH_SIZE * ACC_GRAD_BATCHES)

model = BertClassifier("bert-base-cased")

full_train_dataset = BERTDataset(train_df)
test_dataset = BERTDataset(test_df)

train_size = int(TRAIN_SIZE * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(
    full_train_dataset, [train_size, val_size]
)

dataloader = {
    "train": DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        num_workers=8,
        shuffle=True,
        collate_fn=build_preprocessing_pipeline(model.tokenizer),
    ),
    "val": DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=8,
        collate_fn=build_preprocessing_pipeline(model.tokenizer),
    ),
    "test": DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=8,
        collate_fn=build_preprocessing_pipeline(model.tokenizer),
    ),
}

# arguments are lr, num_training_steps, num_warmup_steps (see configure_optimizers above)
model.configure_optimizers(2e-5, 500, STEPS_PER_EPOCH)
model.freeze_bert()

logger = TensorBoardLogger("./logs", name="dbert")
trainer = pl.Trainer(
    gpus=[0],
    max_epochs=EPOCHS_NUM,
    logger=logger,
    accumulate_grad_batches=ACC_GRAD_BATCHES,
    track_grad_norm=2,
    log_gpu_memory=True,
)
trainer.fit(model, dataloader["train"], dataloader["val"])
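
After fit() finishes I look at the test set roughly like this (a sketch of my evaluation cell, which is why the sklearn/seaborn imports are there; this part is not where the problem is):

preds, labels = get_predictions(dataloader["test"], model)
print(accuracy_score(labels, preds))
print(classification_report(labels, preds))

sn.heatmap(confusion_matrix(labels, preds), annot=True, fmt="d")
plt.show()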