The loss plateau when pretraining BERT using run_mlm.py

Hi, I am trying to pretrain a BERT model from scratch on BookCorpusOpen and Wikipedia using run_mlm.py. One strange thing I noticed is that the loss curve always shows two drops. As the following figure shows (the blue curve), the loss first drops quickly and reaches a plateau. After a while, it drops again and converges. (The two runs in the figure have not converged; I show them only because they are in the same figure and easy to compare. I did train other runs to convergence.) I expected the loss to decrease smoothly. I have tried different hyperparameter settings (learning rate, batch size, warm-up) and different datasets, and all of them show the plateau, sometimes long and sometimes short. Finally, I found that preprocessing the dataset with the --line_by_line parameter mitigates the plateau, as the green curve shows. That curve is still not perfectly smooth, but it is closer to what I expected. Does --line_by_line really have such a large influence? Is this a common problem, or is something wrong with my hyperparameters?
[Figure: MLM training loss curves; blue: default preprocessing, green: with --line_by_line]
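For context on why --line_by_line can matter: as far as I know, run_mlm.py prepares data in one of two ways. Without --line_by_line, it tokenizes all texts, concatenates them, and cuts the stream into fixed-size blocks of max_seq_length tokens (dropping the leftover tail), so almost every example is full length and can span several lines or documents. With --line_by_line, each non-empty line becomes its own example, truncated to max_seq_length. Below is a simplified pure-Python sketch of the two strategies on already-tokenized ids. The function names are hypothetical, and special tokens ([CLS]/[SEP]) and padding are left out.

```python
from typing import List


def group_texts(tokenized: List[List[int]], block_size: int) -> List[List[int]]:
    """Concatenate all token sequences and cut them into equal-size blocks.

    Simplified version of the default (non line_by_line) preprocessing:
    the tail that does not fill a whole block is dropped.
    """
    concatenated = [tok for seq in tokenized for tok in seq]
    total_length = (len(concatenated) // block_size) * block_size
    return [concatenated[i : i + block_size] for i in range(0, total_length, block_size)]


def line_by_line(lines: List[List[int]], max_seq_length: int) -> List[List[int]]:
    """Keep each non-empty line as its own example, truncated to max_seq_length."""
    return [seq[:max_seq_length] for seq in lines if len(seq) > 0]


if __name__ == "__main__":
    # Toy "tokenized lines" of different lengths (ids are arbitrary).
    lines = [[1, 2, 3], [4, 5], [], [6, 7, 8, 9, 10, 11], [12]]
    print(group_texts(lines, block_size=4))
    # [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
    print(line_by_line(lines, max_seq_length=4))
    # [[1, 2, 3], [4, 5], [6, 7, 8, 9], [12]]
```

In the grouped case every example is exactly block_size tokens and crosses line boundaries; in the line-by-line case examples are shorter and vary in length. That is one concrete difference between the two modes that could plausibly change the training dynamics, though it doesn't by itself prove it causes the plateau.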

The model has 6 layers, and the hyperparameters I used are:

python src/run_mlm.py \
    --model_type bert \
    --tokenizer_name bert-base-uncased \
    --config_name model/bert_layer6_512/config.json \
    --do_train \
    --seed 42 \
    --per_device_train_batch_size 16 \
    --gradient_accumulation_steps 16 \
    --learning_rate 1e-4 \
    --num_train_epochs 25 \
    --warmup_steps 8000 \
    --line_by_line  # this parameter affects the loss curve significantly
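With these flags, each optimizer step sees per_device_train_batch_size × gradient_accumulation_steps × (number of devices) sequences, and --warmup_steps counts optimizer steps, not micro-batches. The post doesn't say how many GPUs were used, so the small sketch below takes the device count as a parameter; num_devices=1 is an assumption.

```python
def effective_batch_size(per_device: int, grad_accum: int, num_devices: int) -> int:
    """Number of sequences contributing to each optimizer step."""
    return per_device * grad_accum * num_devices


if __name__ == "__main__":
    # per_device and grad_accum come from the command above;
    # num_devices=1 is an assumption (the post does not say).
    batch = effective_batch_size(per_device=16, grad_accum=16, num_devices=1)
    warmup_steps = 8000
    print(batch)                 # 256
    print(batch * warmup_steps)  # 2048000 sequences seen before the peak LR
```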

I ran into the exact same problem. I thought it was because the initial learning rate was too large, so I decreased it to 2.6e-4, but the loss still plateaus at around 6. Did you find the reason for this?
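On the learning-rate question: unless --lr_scheduler_type is changed, the Trainer used by run_mlm.py defaults to a linear schedule, which ramps the learning rate from 0 to the peak over warmup_steps and then decays it linearly to 0 at the last step. Here is a pure-Python sketch of that shape so you can see what the LR actually is around the plateau. total_steps=100000 is a made-up value, since the real number of steps depends on the dataset size.

```python
def linear_schedule_lr(step: int, peak_lr: float, warmup_steps: int, total_steps: int) -> float:
    """LR at a given optimizer step: linear warmup, then linear decay to 0."""
    if step < warmup_steps:
        # Warmup: scale from 0 up to peak_lr.
        return peak_lr * step / max(1, warmup_steps)
    # Decay: scale from peak_lr down to 0 at total_steps.
    remaining = max(0, total_steps - step)
    return peak_lr * remaining / max(1, total_steps - warmup_steps)


if __name__ == "__main__":
    # peak_lr and warmup_steps come from the original command; total_steps is hypothetical.
    for step in (0, 4000, 8000, 54000, 100000):
        print(step, linear_schedule_lr(step, peak_lr=1e-4, warmup_steps=8000, total_steps=100000))
```

Because the peak is only reached at the end of warmup, the model spends the first 8000 optimizer steps at a reduced learning rate; comparing the step where the plateau starts and ends against this curve is one cheap check.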

I also tried some hyperparameters and have no idea what the exact reason is. But one thing I noticed is that the dataset influences how long the plateau lasts. If you use a small dataset like wikitext, or preprocess a large dataset with --line_by_line (one of the options in run_mlm.py), only a short plateau occurs 🤔.

I also ran into the exact same problem, and I also saw that using the line_by_line argument resolves it.

I inspected the outputs of the data collator, and both methods seem to work properly, so the plateau is probably not the result of a bug. Maybe line_by_line lets the model see shorter sequences, which helps it learn faster, similar to https://aclanthology.org/2021.ranlp-1.112.pdf ?
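One way to sanity-check collator output, as described above, is to confirm that about 15% of non-special tokens are selected for prediction and that the selected ones follow BERT's 80/10/10 rule (replace with [MASK], replace with a random token, keep unchanged). The following is a pure-Python re-implementation of that rule for illustration only, not the library's code. MASK_ID=103 and VOCAB_SIZE=30522 match the bert-base-uncased vocabulary, and -100 is the label value the MLM loss ignores.

```python
import random
from typing import List, Tuple

MASK_ID = 103                # [MASK] in bert-base-uncased
VOCAB_SIZE = 30522           # bert-base-uncased vocabulary size
IGNORE_INDEX = -100          # label value ignored by the MLM loss
SPECIAL_IDS = {0, 101, 102}  # [PAD], [CLS], [SEP]


def mask_tokens(ids: List[int], rng: random.Random,
                mlm_probability: float = 0.15) -> Tuple[List[int], List[int]]:
    """Apply BERT-style 80/10/10 masking; returns (inputs, labels)."""
    inputs, labels = list(ids), [IGNORE_INDEX] * len(ids)
    for i, tok in enumerate(ids):
        if tok in SPECIAL_IDS or rng.random() >= mlm_probability:
            continue  # not selected: no loss on this position
        labels[i] = tok  # the model must predict the original token here
        r = rng.random()
        if r < 0.8:
            inputs[i] = MASK_ID  # 80%: replace with [MASK]
        elif r < 0.9:
            inputs[i] = rng.randrange(VOCAB_SIZE)  # 10%: random token
        # remaining 10%: keep the original token
    return inputs, labels


if __name__ == "__main__":
    rng = random.Random(0)
    seq = [101] + [2000 + (i % 500) for i in range(510)] + [102]
    selected = masked = 0
    for _ in range(200):
        inputs, labels = mask_tokens(seq, rng)
        for x, y in zip(inputs, labels):
            if y != IGNORE_INDEX:
                selected += 1
                masked += x == MASK_ID
    print(f"selected fraction: {selected / (200 * 510):.3f}")    # expected near 0.15
    print(f"[MASK] share of selected: {masked / selected:.3f}")  # expected near 0.8
```

Running the same counts on real batches from both preprocessing modes would show whether they differ in masking statistics or only in sequence length.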