Number of epochs in pre-training BERT


In the BERT paper, it says:

We train with batch size of 256 sequences (256 sequences * 512 tokens = 128,000 tokens/batch) for 1,000,000 steps, which is approximately 40 epochs over the 3.3 billion word corpus.

How does this equation work?
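Taking the quoted numbers at face value, here is a sketch of the arithmetic behind "approximately 40 epochs". It assumes that one corpus "word" counts as one token and that every batch is completely filled with 512-token sequences. Both are simplifying assumptions, not claims from the paper. The helper name `implied_epochs` is hypothetical.

```python
# Numbers quoted from the BERT paper.
BATCH_SEQUENCES = 256
SEQ_LEN = 512
STEPS = 1_000_000
CORPUS_WORDS = 3.3e9

# 256 * 512 is actually 131,072; the paper rounds it to 128,000.
exact_tokens_per_batch = BATCH_SEQUENCES * SEQ_LEN
rounded_tokens_per_batch = 128_000


def implied_epochs(tokens_per_batch, steps=STEPS, corpus_size=CORPUS_WORDS):
    """Epochs = total tokens seen / corpus size (assumes 1 word == 1 token)."""
    return steps * tokens_per_batch / corpus_size


print(f"{implied_epochs(rounded_tokens_per_batch):.1f}")  # 38.8
print(f"{implied_epochs(exact_tokens_per_batch):.1f}")    # 39.7
```

Both values round to "approximately 40". The gap between them comes only from the rounding inside the parentheses.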

What is the unit “word” in “3.3 billion word corpus”? Is it the same as the output of the wc -w command on the entire text corpus? If this unit is a raw token, is there a guarantee that the number of “words” in the entire corpus matches the number of tokens in the whole dataset after data preparation (assuming the duplicate factor is set to 1)?
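As a rough way to measure the "word" unit, the sketch below counts whitespace-delimited words, which approximates what wc -w reports. This is an assumption: GNU wc's exact word definition depends on locale and on how it treats non-printable characters, so small differences are possible. The function names are hypothetical.

```python
from typing import Iterable, List


def count_words(lines: Iterable[str]) -> int:
    """Count whitespace-delimited words, roughly what `wc -w` reports."""
    return sum(len(line.split()) for line in lines)


def count_words_in_files(paths: List[str]) -> int:
    """Sum word counts over several corpus files."""
    total = 0
    for path in paths:
        # Iterate line by line so a multi-GB corpus never has to fit in memory.
        with open(path, encoding="utf-8") as f:
            total += count_words(f)
    return total
```

With this definition "Hello," counts as one word, whereas BERT's tokenizer later splits the comma into its own token. A raw word count and a token count therefore need not agree.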

According to this line of code, some WordPiece tokens in a training instance will be dropped from the front or the back if the sequence is longer than the maximum sequence length. Is this taken into account?
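To see how many tokens truncation can discard, here is a sketch modeled on what I believe is `truncate_seq_pair` in Google's `create_pretraining_data.py`. It repeatedly shortens the longer of the two segments, dropping a token from the front or the back at random. The returned count of dropped tokens is my own addition so that drops can be tallied. I also assume the real script reserves 3 positions for [CLS], [SEP], [SEP], so the limit would be `max_seq_length - 3`.

```python
import random


def truncate_seq_pair(tokens_a, tokens_b, max_num_tokens, rng):
    """Truncate a pair in place until len(a) + len(b) <= max_num_tokens.

    Returns the number of tokens dropped (an addition for counting).
    """
    dropped = 0
    while len(tokens_a) + len(tokens_b) > max_num_tokens:
        # Always shorten the longer of the two segments.
        trunc = tokens_a if len(tokens_a) > len(tokens_b) else tokens_b
        # Drop from the front or the back with equal probability.
        if rng.random() < 0.5:
            del trunc[0]
        else:
            trunc.pop()
        dropped += 1
    return dropped


rng = random.Random(0)
a, b = list(range(10)), list(range(3))
print(truncate_seq_pair(a, b, 8, rng))  # 5
```

Summing the return value over all instances would show how far the token count of the prepared data falls below the raw token count.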

If I understand this function correctly, when the next segment is chosen at random, the segments it replaced (the ones that actually followed segment A) are “put back” (here) so they can be used later. Does this mean that the total number of tokens increases because of these randomly chosen segments?
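The question can be checked with a simplified simulation. The sketch below is my reading of the logic in `create_instances_from_document`, and that reading is an assumption. Segments are represented only by their token lengths. Short-sequence sampling, masking, truncation, and the duplicate factor are all omitted. The function `count_instance_tokens` is hypothetical. It counts tokens taken from each document itself separately from tokens copied in as a random segment B.

```python
import random


def count_instance_tokens(docs, target_len, seed=12345):
    """Simplified model of next-sentence instance creation.

    docs: at least two non-empty documents, each a list of segment
    lengths in tokens. Returns (own_tokens, random_b_tokens), counted
    before any truncation.
    """
    rng = random.Random(seed)
    own_tokens = 0       # tokens from the document itself (A and actual-next B)
    random_b_tokens = 0  # tokens copied in from another document
    for d, doc in enumerate(docs):
        chunk, chunk_len, i = [], 0, 0
        while i < len(doc):
            chunk.append(doc[i])
            chunk_len += doc[i]
            if i == len(doc) - 1 or chunk_len >= target_len:
                # The first a_end segments of the chunk become segment A.
                a_end = rng.randint(1, len(chunk) - 1) if len(chunk) >= 2 else 1
                a_len = sum(chunk[:a_end])
                own_tokens += a_len
                if len(chunk) == 1 or rng.random() < 0.5:
                    # Random next: fill B from a different document.
                    other = docs[rng.choice([k for k in range(len(docs)) if k != d])]
                    b_len = 0
                    for seg in other[rng.randrange(len(other)):]:
                        b_len += seg
                        if b_len >= target_len - a_len:
                            break
                    random_b_tokens += b_len
                    # "Put back" the unused segments so they are processed again.
                    i -= len(chunk) - a_end
                else:
                    # Actual next: the rest of the chunk becomes B.
                    own_tokens += sum(chunk[a_end:])
                chunk, chunk_len = [], 0
            i += 1
    return own_tokens, random_b_tokens


docs = [[10, 20, 30], [5], [15, 25, 35, 45]]
own, extra = count_instance_tokens(docs, target_len=50)
print(own, extra)  # own is always 185 (the corpus total); extra depends on the seed
```

In this model every original segment lands exactly once, either in an A segment or in an actual-next B segment. The surplus over the corpus therefore comes entirely from random B segments, and truncation (not modeled here) removes some tokens again.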

(I opened an issue on Google’s repository, but I wanted to ask this in this community as well.)


@go-inoue Did you find an answer to this question?

My best guess:

1,000,000 steps equals approx. 40 epochs -> 1e6 / 40 = 25,000 steps per epoch.
Each step (iteration) uses a batch of 128,000 tokens -> 25,000 * 128,000 = 3.2 billion tokens per epoch.

One epoch is one full pass over the training data. In other words, the training data contains approx. 3.2 billion tokens.

I would expect the number of tokens to be higher than the number of words in the training data, given that full stops, commas, etc. are separate tokens, and words are sometimes split into several tokens by the BERT tokenizer.
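To illustrate why the token count should exceed the word count, here is a simplified sketch of BERT-style tokenization. It lowercases the text, splits off punctuation, and then applies greedy longest-match-first WordPiece with a "##" continuation prefix. The vocabulary is a tiny hypothetical one made up for this example. The real tokenizer also handles accents, CJK characters, and Unicode punctuation, and its real vocabulary will split words differently.

```python
import string

# Hypothetical toy vocabulary for illustration only.
VOCAB = {"the", "model", "is", "pre", "##train", "##ing",
         "un", "##aff", "##able", ",", "."}


def basic_split(text):
    """Lowercase, split on whitespace, then split off ASCII punctuation."""
    tokens = []
    for word in text.lower().split():
        current = ""
        for ch in word:
            if ch in string.punctuation:
                if current:
                    tokens.append(current)
                    current = ""
                tokens.append(ch)  # punctuation becomes its own token
            else:
                current += ch
        if current:
            tokens.append(current)
    return tokens


def wordpiece(word, vocab):
    """Greedy longest-match-first split; non-initial pieces get a '##' prefix."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        match = None
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:
            return ["[UNK]"]  # no piece matched: the whole word becomes unknown
        pieces.append(match)
        start = end
    return pieces


def tokenize(text, vocab=VOCAB):
    return [p for tok in basic_split(text) for p in wordpiece(tok, vocab)]


text = "The model is pretraining, unaffable."
print(len(text.split()))    # 5 whitespace-delimited "words"
print(len(tokenize(text)))  # 11 WordPiece tokens
```

Here 5 words become 11 tokens: two words are split into three pieces each, and two punctuation marks become separate tokens.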

Could it be that the 3.3 bn word corpus is split into training, validation, and test data? I'm not even sure there is a train/val/test split during BERT pre-training, given that it is unsupervised. Wouldn't some kind of cross-validation over all of the data make more sense?