BERT for NextSentencePrediction: training vs. inference accuracy mismatch

Dear all, I am quite new to HuggingFace but familiar with TF and Torch. I am trying to use BERT's next sentence prediction (NSP) head for next-question prediction: given a first question, I want to predict the next question. I built my script by following a recipe; it is attached below, together with the training log. I know the model is overfitting; that is a separate issue I will deal with later. My first question is about convergence: the model appears to converge on the training set in terms of both loss and accuracy. However, when I run the trained model over the training set again with my inference method, I get the following result:
Average loss: 0.0040, accuracy: 0.0005

If I run the same inference on the validation set instead, the result is: Average loss: 4.8383, accuracy: 0.0000.

However, the accuracy reported during training is close to 1, so why do I get such a low accuracy when running inference on the very same training set? Is there something I did wrong?
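
For reference, my mental model of what the NSP head returns is based on a small standalone check along the lines of the sketch below. It uses the public bert-base-uncased checkpoint and two made-up questions instead of my local nb-bert files, so it only illustrates output shapes and label ordering, not my actual data:

import tensorflow as tf
from transformers import BertTokenizer, TFBertForNextSentencePrediction

# Sketch only: public checkpoint and invented questions, not my local nb-bert setup.
tok = BertTokenizer.from_pretrained('bert-base-uncased')
nsp_model = TFBertForNextSentencePrediction.from_pretrained('bert-base-uncased')

enc = tok('How do I reset my password?',            # first question
          'Where do I find my account settings?',   # candidate next question
          return_tensors='tf')
out = nsp_model(enc)   # accepts the dict of input_ids / token_type_ids / attention_mask
logits = out[0]        # raw scores of shape (1, 2), not probabilities
print(logits.shape)
print(tf.argmax(logits, axis=-1).numpy())  # 0 = "B follows A", 1 = "B is random"

My full script is below, followed by the training log: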

import os
import numpy as np
from pickle import load
import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.callbacks import ModelCheckpoint
from transformers import BertTokenizer, TFBertForNextSentencePrediction

from tensorflow.keras.losses import BinaryCrossentropy
import tensorflow.keras.backend as K


class BERT_Trainer:
    def __init__(self, output_folder,
                 pretrain_model_path,
                 bert_config_file,
                 vocab_file,
                 special_tokens_map_file,
                 tokenizer_config_file):
        pretrained_model = 'nb-bert-base'
        self.pretrain_model_name = pretrained_model
        self.pretrain_model_path = pretrain_model_path
        self.bert_config_file = bert_config_file
        self.max_length = 128
        self.output_folder = output_folder
        self.init_tokenizer(vocab_file, special_tokens_map_file, tokenizer_config_file)

    def init_tokenizer(self, vocab_file, special_tokens_map_file, tokenizer_config_file):
        vocab_files = {
            'vocab_file': vocab_file,
            'special_tokens_map_file': special_tokens_map_file,
            'tokenizer_config_file': tokenizer_config_file
        }
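        # Build the tokenizer from local vocab/config files (note: this relies on the private _from_pretrained helper).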
        tokenizer = BertTokenizer._from_pretrained(vocab_files, self.pretrain_model_name, init_configuration={})
        self.tokenizer = tokenizer

    def set_pretrain_model(self, pretrain_model_path):
        pretrain_model = TFBertForNextSentencePrediction.from_pretrained(pretrain_model_path,
                                                                         config=self.bert_config_file)
        return pretrain_model

    def set_model(self, training):
        pretrain_model = self.set_pretrain_model(self.pretrain_model_path)
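        # Wrap the pretrained NSP model in a Keras functional model whose inputs mirror the tokenizer outputs.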
        input_ids = tf.keras.layers.Input(shape=(self.max_length,), name='input_ids', dtype='int32')
        input_masks_ids = tf.keras.layers.Input(shape=(self.max_length,), name='attention_mask', dtype='int32')
        token_type_ids = tf.keras.layers.Input(shape=(self.max_length,), name='token_type_ids', dtype='int32')
        X = pretrain_model(
            {'input_ids': input_ids, 'attention_mask': input_masks_ids, 'token_type_ids': token_type_ids})
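        # The first element of the model output holds the NSP logits, shape (batch_size, 2).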
        if training:
            X = X[0]
        model = tf.keras.Model(inputs=[input_ids, input_masks_ids, token_type_ids], outputs=X)
        model.summary()
        model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
                      optimizer=tf.optimizers.Adam(learning_rate=0.00001), metrics=['accuracy'])
        return model

    def get_data_features(self, data_inputs):
        data_features = {}
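        # Pad/truncate input_ids, attention_mask and token_type_ids to max_length.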
        for key in data_inputs:
            data_features[key] = sequence.pad_sequences(data_inputs[key],
                                                        maxlen=self.max_length,
                                                        truncating="post",
                                                        padding="post",
                                                        value=0)
        return data_features

    def load_data(self, data_file):
        with open(data_file, 'rb') as f:
            first_questions, next_questions, labels = load(f)
        data_inputs = self.tokenizer(
            first_questions,
            next_questions,
            add_special_tokens=True,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='tf'
        )
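        # Convert the integer labels to one-hot vectors of shape (num_samples, 2).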
        labels = to_categorical(labels)
        data_features = self.get_data_features(data_inputs)
        return data_features, labels

    def train(self, train_data_file, val_data_file, multi_gpu=False, gpu_idx=0):
        callbacks = []
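        # Early stopping, learning-rate reduction and best-checkpoint saving, all monitored on val_loss.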
        early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=40, min_delta=0.005, verbose=1)
        callbacks.append(early_stop)

        reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=40, min_delta=0.001)
        callbacks.append(reduce_lr)

        checkpoint = ModelCheckpoint(os.path.join(self.output_folder, 'bert_nsp.h5'),
                                     monitor="val_loss",
                                     verbose=1,
                                     save_best_only=True,
                                     save_weights_only=True,
                                     mode="min")
        callbacks.append(checkpoint)

        if multi_gpu:
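            # MirroredStrategy replicates the model across all visible GPUs.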
            strategy = tf.distribute.MirroredStrategy(cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())
            with strategy.scope():
                model = self.set_model(training=True)
        else:
            gpus = tf.config.experimental.list_physical_devices('GPU')
            tf.config.experimental.set_visible_devices(gpus[gpu_idx], 'GPU')
            model = self.set_model(training=True)

        model.run_eagerly = True

        train_data_features, train_labels = self.load_data(train_data_file)
        val_data_features, val_labels = self.load_data(val_data_file)

        history = model.fit([train_data_features['input_ids'],
                             train_data_features['attention_mask'],
                             train_data_features['token_type_ids']],
                            np.array(train_labels),
                            batch_size=32,
                            epochs=20,
                            validation_data=([val_data_features['input_ids'],
                                              val_data_features['attention_mask'],
                                              val_data_features['token_type_ids']],
                                             np.array(val_labels)),
                            verbose=1,
                            callbacks=callbacks)

    def inference(self, data_file):
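        # Rebuild the same graph as in training and load the best checkpoint weights.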
        model = self.set_model(training=True)
        weights_file = os.path.join(self.output_folder, 'bert_nsp.h5')
        model.load_weights(weights_file)
        data_features, labels = self.load_data(data_file)

        loss = BinaryCrossentropy(from_logits=True)

        batch_size = 32
        total_loss = 0
        total_accu = 0
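        # Step through the data in fixed-size batches; accuracy here is computed by rounding the raw logits and comparing them element-wise with the one-hot labels.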
        for i in range(len(labels) // batch_size):
            output = model([data_features['input_ids'][i * batch_size: (i + 1) * batch_size, :],
                            data_features['attention_mask'][i * batch_size: (i + 1) * batch_size, :],
                            data_features['token_type_ids'][i * batch_size: (i + 1) * batch_size, :]])

            loss_value = loss(labels[i * batch_size: (i + 1) * batch_size, :], output)
            total_loss += loss_value
            accu = K.mean(K.equal(labels[i * batch_size: (i + 1) * batch_size, :], K.round(output)))
            total_accu += accu

        print('Average loss: {:.4f}, accuracy: {:.4f}'.format(total_loss / (len(labels) // batch_size),
                                                      total_accu / (len(labels) // batch_size)))


if __name__ == '__main__':
    output_folder = r'.\output'
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    pretrain_model_path = r'..\bert_pretrain_weights\bert_base_pretrain.h5'
    bert_config_file = r'..\bert_pretrain_weights\bert_base_config.json'
    vocab_file = r'..\bert_pretrain_weights\vocab_nb.txt'
    special_tokens_map_file = r'..\bert_pretrain_weights\special_tokens_map_nb.json'
    tokenizer_config_file = r'..\bert_pretrain_weights\tokenizer_config.json'

    trainer = BERT_Trainer(
        output_folder,
        pretrain_model_path,
        bert_config_file,
        vocab_file,
        special_tokens_map_file,
        tokenizer_config_file)
    train_data_file = r'..\data_files\train_data_nsp.pkl'
    val_data_file = r'..\data_files\val_data_nsp.pkl'
    
    trainer.train(train_data_file, val_data_file)

    trainer.inference(train_data_file)

Epoch 1/20
158/158 [==============================] - 52s 327ms/step - loss: 0.1747 - accuracy: 0.9510 - val_loss: 4.8420 - val_accuracy: 0.1615

Epoch 00001: val_loss improved from inf to 4.84197, saving model to .\output\bert_nsp.h5
Epoch 2/20
158/158 [==============================] - 53s 335ms/step - loss: 0.0053 - accuracy: 1.0000 - val_loss: 5.0008 - val_accuracy: 0.1615

Epoch 00002: val_loss did not improve from 4.84197
Epoch 3/20
42/158 [======>…] - ETA: 35s - loss: 0.0041 - accuracy: 1.0000