Further train bert with next sentence prediction head using tensorflow

I’m trying to train TFBertForNextSentencePrediction on my own corpus, not from scratch, but rather by taking the existing BERT model with only a next sentence prediction head and further training it on a specific corpus of text (pairs of sentences). Then I want to use the trained model to extract sentence embeddings from the last hidden state for other texts.

Currently the problem I encounter is that after I train the Keras model I am not able to extract the hidden states of the last layer before the next sentence prediction head.

Below is the code. Here I only train it on a few sentences just to make sure the code works. Any help will be greatly appreciated.

Thanks, Ayala

import os
import numpy as np
import pandas as pd
import tensorflow as tf
from datetime import datetime
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing import sequence
from tensorflow.keras.callbacks import ModelCheckpoint
from transformers import BertTokenizer, BertConfig, TFBertForNextSentencePrediction
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, precision_score, recall_score


PRETRAINED_MODEL = 'bert-base-uncased'

# set paths and file names
time_stamp = datetime.now().strftime("%Y_%m_%d_%H_%M")
model_name = "pretrained_nsp_model"
model_dir_data = model_name + "_" + time_stamp
model_fn = model_dir_data + ".h5"
base_path = os.path.dirname(__file__)
input_path = os.path.join(base_path, "input_data")
output_path = os.path.join(base_path, "output_models")
model_path = os.path.join(output_path, model_dir_data)
if not os.path.exists(model_path):
    os.makedirs(model_path)

# set model checkpoint
checkpoint = ModelCheckpoint(os.path.join(model_path, model_fn), monitor="val_loss", verbose=1, save_best_only=True,
                             save_weights_only=True, mode="min")

# read data
max_length = 512

def get_tokenizer(pretrained_model_name):
    tokenizer = BertTokenizer.from_pretrained(pretrained_model_name)
    return tokenizer

def tokenize_nsp_data(A, B, max_length):
    # tokenize sentence pairs, padding/truncating to max_length and returning TF tensors
    data_inputs = tokenizer(A, B, add_special_tokens=True, max_length=max_length, truncation=True,
                            padding="max_length", return_attention_mask=True,
                            return_tensors="tf")
    return data_inputs

def get_data_features(data_inputs, max_length):
    data_features = {}
    for key in data_inputs:
        data_features[key] = sequence.pad_sequences(data_inputs[key], maxlen=max_length, truncating="post",
                                                          padding="post", value=0)
    return data_features

def get_transformer_model(transformer_model_name):
    # load the pretrained config (rather than a fresh default one) and enable hidden states / attentions
    config = BertConfig.from_pretrained(transformer_model_name, output_attentions=True,
                                        output_hidden_states=True, return_dict=True)
    transformer_model = TFBertForNextSentencePrediction.from_pretrained(transformer_model_name, config=config)
    return transformer_model

def get_keras_model(transformer_model):
    # get keras model
    input_ids = tf.keras.layers.Input(shape=(max_length,), name='input_ids', dtype='int32')
    input_masks_ids = tf.keras.layers.Input(shape=(max_length,), name='attention_mask', dtype='int32')
    token_type_ids = tf.keras.layers.Input(shape=(max_length,), name='token_type_ids', dtype='int32')
    X = transformer_model({'input_ids': input_ids, 'attention_mask': input_masks_ids, 'token_type_ids': token_type_ids})[0]
    model = tf.keras.Model(inputs=[input_ids, input_masks_ids, token_type_ids], outputs=X)
    model.summary()
    # the NSP head outputs two logits per pair, so a categorical (softmax) loss matches the one-hot labels
    model.compile(loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
                  optimizer=tf.optimizers.Adam(learning_rate=0.00005), metrics=['accuracy'])
    return model

def get_metrics(true_values, pred_values):
    cm = confusion_matrix(true_values, pred_values)
    acc_score = accuracy_score(true_values, pred_values)
    f1 = f1_score(true_values, pred_values, average="binary")
    precision = precision_score(true_values, pred_values, average="binary")
    recall = recall_score(true_values, pred_values, average="binary")
    metrics = {'confusion_matrix': cm,
               'acc_score': acc_score,
               'f1': f1,
               'precision': precision,
               'recall': recall
               }
    for k, v in metrics.items():
        print(k, ':\n', v)
    return metrics

# get tokenizer
tokenizer = get_tokenizer(PRETRAINED_MODEL)

# train 
prompt = ["Hello", "Hello", "Hello", "Hello"]
next_sentence = ["How are you?", "Pizza", "How are you?", "Pizza"]
train_labels = [0, 1, 0, 1]
train_labels = to_categorical(train_labels)
train_inputs = tokenize_nsp_data(prompt, next_sentence, max_length)
train_data_features = get_data_features(train_inputs, max_length)

# val
prompt = ["Hello", "Hello", "Hello", "Hello"]
next_sentence = ["How are you?", "Pizza", "How are you?", "Pizza"]
val_labels = [0, 1, 0, 1]
val_labels = to_categorical(val_labels)
val_inputs = tokenize_nsp_data(prompt, next_sentence, max_length)
val_data_features = get_data_features(val_inputs, max_length)

# get transformer model
transformer_model = get_transformer_model(PRETRAINED_MODEL)

# get keras model
model = get_keras_model(transformer_model)

callback_list = []
early_stop = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=4, min_delta=0.005, verbose=1)
callback_list.append(early_stop)
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=2, min_delta=0.001)
callback_list.append(reduce_lr)
callback_list.append(checkpoint)

history = model.fit([train_data_features['input_ids'], train_data_features['attention_mask'],
                     train_data_features['token_type_ids']], np.array(train_labels), batch_size=2, epochs=3,
                    validation_data=([val_data_features['input_ids'], val_data_features['attention_mask'],
                                      val_data_features['token_type_ids']], np.array(val_labels)), verbose=1,
                    callbacks=callback_list)

model.layers[3].save_pretrained(model_path)  # need to save this and make sure i can get the hidden states

##  predict
# load model
transformer_model = get_transformer_model(model_path)
model = get_keras_model(transformer_model)
model.summary()
model.load_weights(os.path.join(model_path, model_fn))


# test
prompt = ["Hello", "Hello"]
next_sentence = ["How are you?", "Pizza"]
test_labels = [0, 1]
test_df = pd.DataFrame({'A': prompt, 'B': next_sentence, 'label': test_labels})
test_labels = to_categorical(test_labels)
test_inputs = tokenize_nsp_data(prompt, next_sentence, max_length)
test_data_features = get_data_features(test_inputs, max_length)

# predict
pred_test = model.predict([test_data_features['input_ids'], test_data_features['attention_mask'], test_data_features['token_type_ids']])
preds = tf.keras.activations.softmax(tf.convert_to_tensor(pred_test)).numpy()

true_test = test_df['label'].to_list()
pred_test = [1 if p[1] > 0.5 else 0 for p in preds]
test_df['pred_val'] = pred_test

metrics = get_metrics(true_test, pred_test)

I am also attaching a picture from debugging mode in which I try (with no success) to view the hidden state. The problem is that I am not able to save the transformer model I trained and view the embeddings of its last hidden state. I tried converting the KerasTensor to a NumPy array, but without success.
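For reference, here is a rough sketch of the kind of extraction I am after, calling the saved transformer directly in eager mode instead of through the Keras functional model (the masked mean pooling at the end is just one possible choice, not a fixed part of the pipeline):

# sketch: load the transformer saved with model.layers[3].save_pretrained(model_path)
# and call it directly in eager mode, so the outputs are real tensors rather than KerasTensors
import tensorflow as tf
from transformers import BertTokenizer, TFBertForNextSentencePrediction

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
nsp_model = TFBertForNextSentencePrediction.from_pretrained(model_path, output_hidden_states=True)

sentences = ["How are you?", "Pizza"]
encoded = tokenizer(sentences, padding=True, truncation=True, max_length=512, return_tensors="tf")

outputs = nsp_model(encoded)
last_hidden_state = outputs.hidden_states[-1]  # (batch_size, seq_len, hidden_size)

# masked mean pooling over tokens as one possible sentence embedding
mask = tf.cast(tf.expand_dims(encoded["attention_mask"], -1), last_hidden_state.dtype)
sentence_embeddings = tf.reduce_sum(last_hidden_state * mask, axis=1) / tf.reduce_sum(mask, axis=1)
print(sentence_embeddings.shape)  # (2, 768)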

Hi @ayalaall, did further training on the NSP task on a custom corpus improve the performance of the embeddings?

My use case: further train the bert-base-uncased model on a custom text corpus, then use it to get sentence embeddings.

I tried using only the MLM task for training, but the embedding performance was even poorer than that of the base bert-base model. Is the NSP task very important when the use case is embedding generation?

Hi @pk1203,

I ended up taking another approach: using the sentence-transformers library (PyTorch based) to further train the BERT model on my corpus.
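Roughly along these lines (a minimal sketch with placeholder sentence pairs, labels, loss and hyperparameters, not my exact training script):

# minimal sentence-transformers sketch; the pairs, labels, loss and
# hyperparameters below are placeholders, not the real training setup
from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, InputExample, losses

# wraps bert-base-uncased with a mean-pooling layer on top
model = SentenceTransformer('bert-base-uncased')

train_examples = [
    InputExample(texts=["Hello", "How are you?"], label=1.0),
    InputExample(texts=["Hello", "Pizza"], label=0.0),
]
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=2)
train_loss = losses.CosineSimilarityLoss(model)

model.fit(train_objectives=[(train_dataloader, train_loss)], epochs=1, warmup_steps=10)

# sentence embeddings for new texts come straight from encode()
embeddings = model.encode(["How are you?", "Pizza"])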

Hi,
I’m getting back to this now. Any luck with NSP? Are you using TensorFlow or PyTorch? Did you try adding a learning rate scheduler such as PolynomialDecay to the optimizer? See Fine-tuning a pretrained model - Hugging Face Course.
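For reference, a PolynomialDecay schedule can be attached to the Keras optimizer roughly like this (a sketch; the step counts and learning rates are placeholders, and model is the Keras model built in the code above):

# sketch: attach a PolynomialDecay learning-rate schedule to Adam;
# num_epochs and steps_per_epoch are placeholders for the real training setup
import tensorflow as tf

num_epochs = 3
steps_per_epoch = 100  # roughly len(train_examples) // batch_size
num_train_steps = num_epochs * steps_per_epoch

lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
    initial_learning_rate=5e-5,
    end_learning_rate=0.0,
    decay_steps=num_train_steps,
)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

# the schedule replaces a fixed learning rate (and the ReduceLROnPlateau callback)
model.compile(loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
              optimizer=optimizer, metrics=['accuracy'])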