Fine-tuning BERT to adapt to a newly added class

I have a problem where I have trained a BertForSequenceClassification model to classify 5 classes. Now I want this model to retain its weights for the previous classes and be trained on a newly added 6th class. I have tried freezing layers of BERT and retraining, but I couldn't succeed. Can anyone please help me with this?
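
To make the goal concrete, the head-expansion step I am after looks roughly like this in isolation (the checkpoint path is only a placeholder for wherever the 5-class model was saved):

import torch
import torch.nn as nn
from transformers import BertForSequenceClassification

# Start from the checkpoint fine-tuned on the original 5 classes.
model = BertForSequenceClassification.from_pretrained("path/to/5_class_checkpoint", num_labels=5)

# Build a 6-way classifier, copy over the 5 trained rows, and let the new row start fresh.
new_head = nn.Linear(model.classifier.in_features, 6)
with torch.no_grad():
    new_head.weight[:5] = model.classifier.weight
    new_head.bias[:5] = model.classifier.bias
model.classifier = new_head

# Keep the config in sync so the model's built-in loss treats the logits as 6-way.
model.config.num_labels = 6
model.num_labels = 6

# Freezing the whole encoder is the simplest variant; in the trainer below I only freeze the first 6 layers.
for param in model.bert.parameters():
    param.requires_grad = False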

Below I am attaching my code for the trainer:
# Imports used below; evaluate, f1_score_func and Validator are my own helpers defined elsewhere.
import os
import traceback

import pandas as pd
import torch
import torch.nn as nn
from torch.optim import AdamW  # assuming torch.optim's AdamW
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm
from transformers import BertForSequenceClassification, get_linear_schedule_with_warmup

def trainer(dataset_train, dataset_val, hyperparameter, number_of_labels):
model_report = []

try:
    hyperparameters = hyperparameter.split("_")

    batch_size = int(hyperparameters[0])
    learning_rate = float(hyperparameters[1])
    epochs = int(hyperparameters[2])

    print("Hyperparameters set for BERT - batch size: %d, learning rate: %f, epoch: %d" % (batch_size, learning_rate, epochs))

    # Load and adjust the model
    model = BertForSequenceClassification.from_pretrained(os.getcwd(), num_labels=5)

    # Adjust the classifier for 6 classes
    original_classifier = model.classifier
    new_classifier = nn.Linear(original_classifier.in_features, 6)
    with torch.no_grad():
        new_classifier.weight[:5] = original_classifier.weight
        new_classifier.bias[:5] = original_classifier.bias
        new_classifier.weight[5].zero_()
        new_classifier.bias[5].zero_()
    model.classifier = new_classifier
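    # Assumption: num_labels should also be updated to match the new head, otherwise the
    # model's built-in cross-entropy loss still reshapes the logits as 5-way.
    model.config.num_labels = 6
    model.num_labels = 6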

    # Freeze the initial layers of BERT
    for name, param in model.bert.named_parameters():
        if 'layer.' in name:
            layer_num = int(name.split('layer.')[1].split('.')[0])
            if layer_num < 6:  # Adjust this number based on how many layers you want to freeze
                param.requires_grad = False
            else:
                param.requires_grad = True
        else:
            param.requires_grad = True

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

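    # Shuffle the training data each epoch; keep the validation order fixed.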
    dataloader_train = DataLoader(dataset_train, sampler=RandomSampler(dataset_train), batch_size=batch_size)
    dataloader_validation = DataLoader(dataset_val, sampler=SequentialSampler(dataset_val), batch_size=batch_size)

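    # Only the parameters left unfrozen above are passed to the optimizer; linear decay, no warmup.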
    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate, eps=1e-8)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=len(dataloader_train) * epochs)

    for epoch in tqdm(range(1, epochs + 1)):
        model.train()
        loss_train_total = 0

        progress_bar = tqdm(dataloader_train, desc='Epoch {:1d}'.format(epoch), leave=False, disable=False)

        for batch in progress_bar:
            model.zero_grad()
            batch = tuple(b.to(device) for b in batch)

            inputs = {'input_ids': batch[0], 'attention_mask': batch[1], 'labels': batch[2]}

            # Debugging statement: Print input shapes
            print(f"Input IDs shape: {inputs['input_ids'].shape}")
            print(f"Attention Mask shape: {inputs['attention_mask'].shape}")
            print(f"Labels shape: {inputs['labels'].shape}")

            outputs = model(**inputs)

            # Debugging statement: Print output shapes
            print(f"Logits shape: {outputs.logits.shape}")
            print(f"Expected shape: {inputs['labels'].shape[0]}, {number_of_labels}")

            loss = outputs.loss
            loss_train_total += loss.item()
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            scheduler.step()

            progress_bar.set_postfix({'training_loss': '{:.3f}'.format(loss.item() / len(batch))})

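        # Save a checkpoint to the working directory after the fifth epoch.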
        if epoch == 5:
            model.save_pretrained(os.getcwd())

        tqdm.write(f'\nEpoch {epoch}')
        loss_train_avg = loss_train_total / len(dataloader_train)
        tqdm.write(f'Training loss: {loss_train_avg}')

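        # Run validation after every epoch and track the weighted F1.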
        val_loss, predictions, true_vals = evaluate(dataloader_validation, model)
        val_f1 = f1_score_func(predictions, true_vals)
        tqdm.write(f'Validation loss: {val_loss}')
        tqdm.write(f'F1 Score (Weighted): {val_f1}')

        model_report.append([f'{batch_size}_{learning_rate}_{epoch}', loss_train_avg, val_loss, val_f1])
        print("Epoch: %d, Training loss: %f, Validation loss: %f, F1 Score: %f" % (epoch, loss_train_avg, val_loss, val_f1))

    df = pd.DataFrame(model_report, columns=['Model', 'training_loss', 'validation_loss', 'f1_score'])
    df.to_csv('report_bert.csv', index=False)

    print("Validation Started")
    Validator(dataloader_validation, model)

    print("BERT Training and Validation Completed")

except Exception as e:
    print("Error: %s, traceback: %s" % (e, traceback.format_exc()))
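
For reference, the trainer is meant to be called along these lines (the hyperparameter string is split on underscores into batch size, learning rate and number of epochs; the values here are just placeholders):

trainer(dataset_train, dataset_val, "16_2e-05_5", 6)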