The loss plateau when pretraining BERT from scratch using run_mlm.py

Hi, I am trying to pretrain a BERT model from scratch on BookCorpusOpen and Wikipedia with run_mlm.py. One strange thing I noticed is that the loss curves always show two separate drops. As the figure below shows (the blue curve), the loss first drops quickly and reaches a plateau; after a while it drops again and starts converging. (The two runs shown are not trained to convergence; I show them because they are in the same figure and easy to compare. I do train other runs to convergence.) I expected a normal curve to drop smoothly. I have tested different hyper-parameter settings, including the learning rate, batch size, warm-up, and dataset, and they all show the loss plateau, whether long or short. Finally, I found that pre-processing the dataset with the "line_by_line" option mitigates the plateau, as the green curve shows. It is still not perfectly smooth, but it matches what I expected much better. Does the "line_by_line" option really have that much influence? I wonder whether this is a common problem or whether something is wrong with my hyper-parameters.
[figure: training loss curves (blue: default preprocessing, green: with line_by_line)]

The model has 6 layers, and the hyper-parameters I used are:

python src/run_mlm.py \
    --model_type bert \
    --tokenizer_name bert-base-uncased \
    --config_name model/bert_layer6_512/config.json \
    --do_train \
    --seed 42 \
    --per_device_train_batch_size 16 \
    --gradient_accumulation_steps 16 \
    --learning_rate 1e-4 \
    --num_train_epochs 25 \
    --warmup_steps 8000 \
    --line_by_line  # this option affects the plateau significantly
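
For context, here is roughly what the two preprocessing paths do, condensed from the run_mlm.py example script (this is a paraphrase rather than the script itself; helper names and details differ across transformers versions). Without line_by_line the texts are tokenized, concatenated, and cut into fixed-size chunks, so almost every training example is a full max_seq_length block that can cross sentence and document boundaries. With line_by_line each non-empty line becomes its own, usually much shorter, example.

from itertools import chain
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
max_seq_length = 512
texts = ["A short sentence.", "Another line of text from the corpus."]  # toy corpus

# Default path (no --line_by_line): tokenize, concatenate everything,
# then slice into fixed-size chunks that may span document boundaries.
tokenized = tokenizer(texts)["input_ids"]
concatenated = list(chain(*tokenized))
total_length = (len(concatenated) // max_seq_length) * max_seq_length
grouped_examples = [
    concatenated[i : i + max_seq_length]
    for i in range(0, total_length, max_seq_length)
]

# --line_by_line path: every non-empty line is its own example,
# truncated here and padded later by the data collator, so the model
# sees many short sequences with natural sentence boundaries.
line_examples = tokenizer(
    [t for t in texts if t.strip()],
    truncation=True,
    max_length=max_seq_length,
)["input_ids"]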

I ran into the exact same problem. I thought the initial learning rate was too large, so I decreased it to 2.6e-4, but the loss still plateaus around 6. Did you find the reason for this?

I also tried some hyper-parameters and have no idea of the exact reason. But one thing I noticed is that the dataset influences the length of the plateau. If you use a small dataset like WikiText, or preprocess a large dataset with "line_by_line" (an option in the run_mlm.py script), only a short plateau occurs 🤔.

I also ran into the exact same problem, and I can confirm that using the line_by_line argument resolves it.

I inspected the outputs of the data collator, and both methods seem to work properly, so the plateau is probably not the result of a bug. Maybe line_by_line exposes the model to shorter sentences, which helps it learn faster, similar to https://aclanthology.org/2021.ranlp-1.112.pdf ?
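
In case it helps anyone reproduce that inspection, here is one way to sanity-check the collator outputs (just a sketch of a possible check, not the exact one I ran): roughly 15% of the tokens should get a label other than -100, and most of those should appear as [MASK] in the inputs.

from transformers import AutoTokenizer, DataCollatorForLanguageModeling

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

sentences = [
    "The quick brown fox jumps over the lazy dog.",
    "Pretraining BERT from scratch takes a while.",
]
features = [
    {"input_ids": tokenizer(s, truncation=True, max_length=128)["input_ids"]}
    for s in sentences
]
batch = collator(features)

# Tokens whose label != -100 are the ones the model must predict; roughly
# 15% of non-special tokens should be selected (the printed fraction is a
# bit lower because special and padding tokens are never chosen), and most
# of the selected positions should be replaced by [MASK] in input_ids.
selected = batch["labels"] != -100
print("fraction selected for prediction:", selected.float().mean().item())
print("[MASK] tokens in inputs:",
      (batch["input_ids"] == tokenizer.mask_token_id).sum().item())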

My hypothesis is that the plateau corresponds to the level of loss the model can reach by relying only on word-frequency statistics, without using any context, and that it simply takes more time for the model to start exploiting context.
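
That would be easy to sanity-check: a predictor that ignores context can at best output the unigram distribution over wordpieces, so its cross-entropy equals the unigram entropy of the corpus. Below is a rough back-of-the-envelope estimate (a sketch only, using a WikiText sample as a stand-in for the actual training corpus); if the number lands near the plateau level of ~6, the hypothesis is at least plausible.

import math
from collections import Counter

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Stand-in corpus sample; swap in your own training data for a closer estimate.
sample = load_dataset("wikitext", "wikitext-103-raw-v1", split="train[:1%]")

counts = Counter()
for text in sample["text"]:
    counts.update(tokenizer(text, add_special_tokens=False)["input_ids"])

total = sum(counts.values())
# Entropy of the wordpiece unigram distribution, in nats (the same units
# as the cross-entropy loss reported during MLM training).
unigram_entropy = -sum((c / total) * math.log(c / total) for c in counts.values())
print(f"unigram entropy: {unigram_entropy:.2f} nats")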