Continual pre-training from an initial checkpoint with MLM and NSP

I’m trying to further pre-train a language model (BERT here), not from scratch but from an initial checkpoint, using my own data. My goal is to later use these further pre-trained models for fine-tuning on some downstream tasks (I have no issue with the fine-tuning part). For the pre-training, I want to use both the Masked Language Modeling (MLM) and Next Sentence Prediction (NSP) heads, the same way BERT is pre-trained, where the model’s total loss is the sum of the MLM loss and the NSP loss. My data is stored in a text file following the standard format for BERT input (each document has multiple sentences separated by newlines, and documents are separated by an empty line):


sentence 1.1
sentence 1.2

sentence 2.1
sentence 2.2
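
For reference, here is a minimal sketch of how a file in this format could be read into a list of documents (the file name is a placeholder, not my actual data):

documents = []   # each document is a list of sentences
current = []
with open('corpus.txt', encoding='utf-8') as f:
  for line in f:
    line = line.strip()
    if line:                 # non-empty line -> a sentence of the current document
      current.append(line)
    elif current:            # empty line -> end of the current document
      documents.append(current)
      current = []
if current:                  # handle a missing trailing empty line
  documents.append(current)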

I have two specific questions and I appreciate any feedback:

  1. I have some trouble finding the right function/script in the transformers library for this purpose. As far as I understand, all the scripts for language modeling only use MLM for pre-training (correct me if I’m wrong). Should I use BertForPreTraining for this?

  2. Assuming I should use BertForPreTraining, how should I prepare my data for this model? I’m looking for the right object or data type/format and the right way of tokenizing my input so that it’s suitable for both MLM and NSP.


Hi.

I want to do exactly the same thing as you. Did you find an answer out there? How did things go with your approach? I’d appreciate any advice that would save me some headaches.

Thanks in advance.

BertForPreTraining has both heads (MLM and NSP), so your assumption is correct.
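
As a quick sanity check, a minimal sketch like the one below (the sentence pair is just dummy input) shows that when you pass both labels and next_sentence_label, the returned loss is already the sum of the MLM and NSP losses:

import torch
from transformers import BertForPreTraining, BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertForPreTraining.from_pretrained('bert-base-uncased')

enc = tokenizer('This is the first sentence.', 'This is the second sentence.', return_tensors='pt')
outputs = model(**enc,
                labels=enc['input_ids'],                    # MLM labels (no masking here, just a shape check)
                next_sentence_label=torch.LongTensor([0]))  # 0 = IsNext
print(outputs.loss)  # combined MLM + NSP loss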

I recently found this script, which I used to train both the MLM and NSP heads. I had one document containing multiple paragraphs, where each paragraph pertained to one topic (I use each paragraph as a separate context for extractive QA). This gave me a list of strings where each string is a paragraph (i.e. multiple sentences), stored in the ‘text’ variable.
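
For reference, ‘text’ is just a plain Python list along these lines (the contents here are made up):

text = [
  "First sentence of topic one. Second sentence of topic one. Third sentence of topic one.",
  "First sentence of topic two. Second sentence of topic two.",
]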

NSP Training Data & Labels

import random

sentence_a = []
sentence_b = []
label = []

# flat "bag" of all sentences in the corpus, used to draw random IsNotNext sentences
bag_of_sentences = [sentence.strip() for paragraph in text for sentence in paragraph.split('.') if sentence.strip() != '']
bag_size = len(bag_of_sentences)

for paragraph in text:
  sentences = [sentence.strip() for sentence in paragraph.split('.') if sentence.strip() != '']
  num_sentence = len(sentences)
  if num_sentence > 1:  # only paragraphs with more than one sentence can form an IsNext pair
    start = random.randint(0, num_sentence-2)  # random sentence index between 0 and num_sentence-2
    if random.random() > 0.5:
      # "IsNext" pair: sentence_b is the sentence that actually follows sentence_a in the paragraph
      sentence_a.append(sentences[start])
      sentence_b.append(sentences[start+1])
      label.append(0)  # the label for IsNext = 0
    else:
      # "IsNotNext" pair: sentence_b is a random sentence drawn from the entire corpus
      index = random.randint(0, bag_size-1)
      sentence_a.append(sentences[start])
      sentence_b.append(bag_of_sentences[index])
      label.append(1)  # the label for IsNotNext = 1

import torch
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
inputs = tokenizer(sentence_a, sentence_b, return_tensors='pt',
                   max_length=512, truncation=True, padding='max_length')

# Creating labels for NSP, shape (num_pairs, 1)
inputs['next_sentence_label'] = torch.LongTensor([label]).T

Creating 15% Random Input_Ids Masks for MLM

# MLM labels are the original (unmasked) token ids
inputs['labels'] = inputs.input_ids.detach().clone()

# mask ~15% of tokens, excluding [CLS] (101), [SEP] (102) and [PAD] (0)
rand = torch.rand(inputs.input_ids.shape)
mask_arr = (rand < 0.15) * (inputs.input_ids != 101) * (inputs.input_ids != 102) * (inputs.input_ids != 0)

# indices of the tokens to mask, per example
selection = []
for i in range(inputs.input_ids.shape[0]):
  selection.append(torch.flatten(mask_arr[i].nonzero()).tolist())

# replace the selected tokens with the [MASK] token id (103)
for i in range(inputs.input_ids.shape[0]):
  inputs.input_ids[i, selection[i]] = 103

This will give you the training data required to further pre-train a BERT model. I just used the ‘bert-base-uncased’ model.
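
In case it helps, here is roughly the training loop I would pair with the ‘inputs’ built above (a sketch only; the batch size, learning rate, epoch count and output path are placeholders):

import torch
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from transformers import BertForPreTraining

class PretrainDataset(Dataset):
  def __init__(self, encodings):
    self.encodings = encodings
  def __len__(self):
    return self.encodings['input_ids'].shape[0]
  def __getitem__(self, idx):
    return {key: val[idx] for key, val in self.encodings.items()}

loader = DataLoader(PretrainDataset(inputs), batch_size=8, shuffle=True)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BertForPreTraining.from_pretrained('bert-base-uncased').to(device)
model.train()
optimizer = AdamW(model.parameters(), lr=5e-5)

for epoch in range(2):
  for batch in loader:
    optimizer.zero_grad()
    batch = {key: val.to(device) for key, val in batch.items()}
    outputs = model(**batch)  # loss = MLM loss + NSP loss
    outputs.loss.backward()
    optimizer.step()

model.save_pretrained('./bert-further-pretrained')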

My question for you is how do I further pre-train this model from a chosen checkpoint?

@kmysiak Regarding your question, my understanding is that if you load a model with the from_pretrained method, given either a model name or a path to your checkpoint, training will start from that checkpoint. One thing to keep in mind (you may already know this) is that further pre-training can also cause the model to forget some of the knowledge it already has (catastrophic forgetting). That is something you need to test and run experiments to figure out.
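
Concretely (just a sketch, with a placeholder path), loading from your own saved checkpoint instead of the hub name looks like this:

from transformers import BertForPreTraining, BertTokenizerFast

# checkpoint from the hub
model = BertForPreTraining.from_pretrained('bert-base-uncased')

# or a local directory produced by save_pretrained() / the Trainer
model = BertForPreTraining.from_pretrained('./path/to/checkpoint')
tokenizer = BertTokenizerFast.from_pretrained('./path/to/checkpoint')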

Thanks for validating my assumption. I will have to compare performance with and without continual pre-training.