Loss exploding/increasing in pretraining

I’ve been trying to train a customized version of GPT-NeoX on the MiniPile dataset. My training code is fairly minimal; the only oddities are ReLU^2 as the activation, T5’s tokenizer, and tied embeddings, since I’m comparing this model against a non-transformer model that uses all three. I’m training on 8 GPUs.

from transformers import AutoTokenizer, GPTNeoXConfig, GPTNeoXForCausalLM, DataCollatorForLanguageModeling
from datasets import load_dataset

tokenizer = AutoTokenizer.from_pretrained("t5-small")

cfg = GPTNeoXConfig(
    vocab_size = len(tokenizer),
    hidden_size = 768,
    intermediate_size = 768*4,
    num_hidden_layers = 11,
    num_attention_heads = 12,
    hidden_act = "relu2",
    max_position_embeddings = 1024,
    tie_word_embeddings = True
)

model = GPTNeoXForCausalLM(cfg)

ds = load_dataset("JeanKaddour/minipile", split="train", cache_dir="/workspace/hf_cache")
ds_val = load_dataset("JeanKaddour/minipile", split="validation", cache_dir="/workspace/hf_cache")

x_n_max = 1024

def tokenize_fn(x):
    return tokenizer(x['text'], max_length = x_n_max, truncation = True)

toks = ds.map(tokenize_fn, batched = True, remove_columns=["text"]).shuffle(seed=42)
toks_val = ds_val.map(tokenize_fn, batched = True, remove_columns=["text"])

data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False, return_tensors='pt')

from transformers import Trainer, TrainingArguments
import wandb
wandb.login()

pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Parameters: {pytorch_total_params/1000000}M")

args = TrainingArguments(
    output_dir="gptx",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    evaluation_strategy="steps",
    eval_steps = 100,
    logging_steps = 100,
    gradient_accumulation_steps=1,
    num_train_epochs=1,
    lr_scheduler_type="cosine",
    warmup_steps = 1000,
    bf16=True,
    report_to="wandb",
    load_best_model_at_end=True
)

trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    data_collator=data_collator,
    train_dataset=toks,
    eval_dataset=toks_val
)

trainer.train()
wandb.finish()
trainer.save_model("gptx")

However, although the loss initially decreases, after a few hundred steps it starts increasing again:

Step 	Training Loss 	Validation Loss
100 	9.824500 	9.099127
200 	8.797100 	8.515280
300 	8.215400 	7.917030
400 	7.648100 	7.450994
500 	7.399400 	7.420805
600 	7.374300 	7.341578
700 	7.382400 	7.465085
800 	7.594300 	7.760824
900 	7.914200 	8.104758
1000 	8.256700 	8.487398
1100 	8.696600 	8.936168
1200 	9.142600 	9.377343

When I once left it running, the loss kept increasing all the way past 11.5, which is worse than even the loss of uniform random token prediction (a quick sanity check of that baseline is included after the next table). The same thing happens when I use LLaMA with a similar configuration instead of GPT-NeoX, when I use the default activation function instead of relu2, when I untie the embeddings, and under varied learning rates, so the problem is not the nonstandard relu2 activation. For example:

cfg = GPTNeoXConfig(
    vocab_size = len(tokenizer),
    hidden_size = 768,
    intermediate_size = 768*4,
    num_hidden_layers = 12,
    num_attention_heads = 12,
    tie_word_embeddings = True,
    bos_token_id = tokenizer.bos_token_id,
    eos_token_id = tokenizer.eos_token_id,
    tokenizer_class = "T5TokenizerFast"
)

yields

100 	9.347200 	8.959894
200 	8.786000 	8.624897
300 	8.459400 	8.292471
400 	8.108800 	7.929269
500 	7.737300 	7.550851
600 	7.386000 	7.236368
700 	7.134700 	7.065014
800 	7.073900 	7.122225
900 	7.251800 	7.446449
1000 	7.621200 	7.849643
1100 	8.031700 	8.260697
1200 	8.493300 	8.792882
1300 	9.046300 	9.347054
1400 	9.596400 	9.945238
1500 	10.211000 	10.569137
1600 	10.859000 	11.239925
1700 	11.563300 	11.997416
1800 	12.349400 	12.686112
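
For the “worse than uniform random” point above, this is the baseline I mean (a small sanity-check snippet; the exact value depends on the tokenizer’s vocabulary size):

import math
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
# Cross-entropy of a uniform distribution over the vocabulary is ln(V)
uniform_baseline = math.log(len(tokenizer))
print(f"Uniform-random baseline loss: {uniform_baseline:.2f}")  # ~10.4 for the ~32k-token T5 vocab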

In contrast, a custom JAX model (non-transformer) gets to a loss of 3.6 on the same dataset without trouble.

Am I doing something wrong? I can’t seem to find anything obviously incorrect, but perhaps someone here has a better idea of what’s going on?

Update: It was the T5 tokenizer; it seems it doesn’t play nicely with other model families out of the box.
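
For anyone hitting the same thing, here’s a quick way to see a couple of the T5 tokenizer quirks that can trip up a decoder-only setup (just an illustrative check, not a full diagnosis of my case):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("t5-small")
print(tok.bos_token)   # None: there is no beginning-of-sequence token
print(tok.eos_token)   # '</s>'
# The SentencePiece normalization collapses whitespace, so newlines get lost:
print(tok.decode(tok("line one\nline two")["input_ids"]))

One option would be to switch to a byte-level tokenizer built for causal LMs, e.g. the GPT-NeoX one (EleutherAI/gpt-neox-20b), which avoids these quirks.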