Sudden Training Loss Spike After 8 Hours of Pre-training a BERT Model

(This issue was first posted on GitHub, and I was advised to post it here.)
I am pre-training a BERT model. I did it first with BERT Base and the results were very good: the training loss got to less than 2.8 on my dataset (after 5 epochs in 10 days):
[image: BERT Base training loss curve]

But when I tried BERT Large, the loss got stuck around 8.0 in less than an hour and didn't change even after 10 hours (same training args as BERT Base, but with the batch size halved to fit in memory).
I tried different things like warmup and different learning rates, to no avail. The GPU memory usage for BERT Large was around 39 GB (compared to 33 GB for BERT Base with a batch size of 32).
[image: BERT Large training loss stuck around 8.0]

So I thought maybe the model would work if I made it a bit smaller. I reduced the number of layers from 24 to 20, and it trained fine for a while, but before even reaching the middle of the first epoch the loss spiked back up to 8.0. I have never seen such behavior in transformer models.

[image: loss curve of the 20-layer model spiking back up to 8.0]

The learning rate graph looks completely normal. This just looks like an infamous case of exploding gradients, I suppose, but I haven't heard of them happening in transformer models.
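One way to check the exploding-gradient theory would be to log the global gradient norm every step. A minimal sketch, assuming the same Trainer setup as below (the GradNormTrainer name is just an illustration, not part of the library):

from transformers import Trainer

class GradNormTrainer(Trainer):
    """Sketch: log the global gradient L2 norm right after each backward pass."""

    def training_step(self, model, inputs):
        loss = super().training_step(model, inputs)  # forward + backward
        sq_sum = 0.0
        for p in model.parameters():
            if p.grad is not None:
                sq_sum += p.grad.detach().float().norm(2).item() ** 2
        self.log({"grad_norm": sq_sum ** 0.5})  # shows up alongside the loss logs
        return loss

Trainer already clips gradients to max_grad_norm (default 1.0) before the optimizer step, so this logs the pre-clipping norm; a sudden jump right before the loss spike would support the exploding-gradient theory.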

System Info:

OS: Windows Server 2019
GPU: A100 40 GB
RAM: 128 GB

Python 3.10.11
transformers 4.31.0
PyTorch 2.0.1+cu118

from transformers import BertConfig, BertForMaskedLM, Trainer, TrainingArguments

config_large = BertConfig(vocab_size=50_000, hidden_size=1024, intermediate_size=4096,
                          num_attention_heads=16, num_hidden_layers=20)  # layers reduced from 24
model_large = BertForMaskedLM(config=config_large)
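As a quick sanity check that the 20-layer change actually shrinks the model, the parameter count can be printed (illustrative snippet, using the model_large defined above):

n_params = sum(p.numel() for p in model_large.parameters())
print(f"parameters: {n_params / 1e6:.1f}M")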

training_args = TrainingArguments(
    output_dir="./mymodel",
    overwrite_output_dir=False,  # changed
    num_train_epochs=20,
    per_gpu_train_batch_size=16,  # deprecated alias of per_device_train_batch_size
    #
    logging_steps=0.00001,
    save_strategy="steps",
    save_steps=0.008,
    save_total_limit=20,
    evaluation_strategy="steps",
    eval_steps=0.008,
    #
    learning_rate=5e-5,
    warmup_steps=20000,
    #
    tf32=True,
    optim="adamw_torch_fused",
    group_by_length=True,
    #
    prediction_loss_only=True,
    #
    hub_model_id="mymodel",
    push_to_hub=True,
    hub_strategy="every_save",
    hub_private_repo=True,
)

trainer = Trainer(
    model=model_large,
    args=training_args,
    data_collator=data_collator_wwm,
    train_dataset=data['train'],
    eval_dataset=data['test']
)
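Training is then launched in the usual way; resume_from_checkpoint can pick up from the last saved checkpoint when restarting a run (standard Trainer API, shown here just for completeness):

trainer.train()
# or, to continue from the most recent checkpoint in output_dir:
# trainer.train(resume_from_checkpoint=True)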