ALBERT LM on WikiText-2

I am trying to train an ALBERT masked language model from scratch on the WikiText-2 dataset. The issue is that during training the loss converges to ~6, and the model ends up predicting the same token at every position.

Here is my training code:

import torch
import transformers
from torch.utils.data import DataLoader, RandomSampler

#-------Model setup--------
device = torch.device('cuda')
config = transformers.AlbertConfig(hidden_size=768,
                                   intermediate_size=3072,
                                   num_attention_heads=12)
model = transformers.AlbertForMaskedLM(config).to(device)
tokenizer = transformers.AlbertTokenizer.from_pretrained('albert-base-v2')
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)


#-------Data Setup-------
data = transformers.TextDataset(tokenizer=tokenizer, 
                                file_path="./wikitext-2/wiki.train.tokens", 
                                block_size=128)

data_collator = transformers.DataCollatorForLanguageModeling(tokenizer=tokenizer,
                                                             mlm=True,
                                                             mlm_probability=0.15)
sampler = RandomSampler(data)
data_loader = DataLoader(data, 
                         batch_size=8, 
                         sampler=sampler, 
                         collate_fn=data_collator)
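
For illustration (this check is not part of my original script), pulling one collated batch out of the loader confirms the shapes and that roughly 15% of positions carry MLM labels:

# Illustrative sanity check: inspect one collated batch from the loader
batch = next(iter(data_loader))
print(batch['input_ids'].shape)                   # expected: torch.Size([8, 128])
print((batch['labels'] != -100).float().mean())   # roughly 0.15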


#-------Training Loop-------
n_epochs=20
for epoch in range(1, n_epochs+1):
    for i, batch in enumerate(data_loader):
        loss, _ = model(batch['input_ids'].to(device),
                        labels=batch['labels'].to(device))
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

After 20 epochs, the loss is still at ~6. When I then feed the model text with a masked token, it returns the following:

text = "Hello what is [MASK] name"
text = tokenizer(text, return_tensors="pt")
labels = torch.tensor([-100, -100, -100, -100, 154, -100, -100])
loss, output = model(text, labels=labels)

print(loss, output)

(tensor(9.5703, device='cuda:0', grad_fn=<NllLossBackward>),
 tensor([[[-8.9541,  5.2743, -6.8293,  ..., -8.8524, -4.8818, -9.2188],
          [-8.9541,  5.2743, -6.8293,  ..., -8.8524, -4.8818, -9.2188],
          [-8.9541,  5.2743, -6.8293,  ..., -8.8524, -4.8818, -9.2188],
          ...,
          [-8.9541,  5.2743, -6.8293,  ..., -8.8524, -4.8818, -9.2188],
          [-8.9541,  5.2743, -6.8293,  ..., -8.8524, -4.8818, -9.2188],
          [-8.9541,  5.2743, -6.8293,  ..., -8.8524, -4.8818, -9.2188]]],
        device='cuda:0', grad_fn=<AddBackward0>))

# Print the highest-scoring token at each position
print(torch.max(output, dim=2))

torch.return_types.max(
values=tensor([[6.5202, 6.5202, 6.5202, 6.5202, 6.5202, 6.5202, 6.5202]],
       device='cuda:0', grad_fn=<MaxBackward0>),
indices=tensor([[13, 13, 13, 13, 13, 13, 13]], device='cuda:0'))

As can be seen, every position gets exactly the same scores, and the top prediction everywhere is token id 13, which decodes to an empty/'blank' piece.
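
(To double-check which piece id 13 actually is, I decoded it with the tokenizer; snippet included here for clarity.)

# Look up the piece that id 13 maps to in the albert-base-v2 vocabulary
print(tokenizer.convert_ids_to_tokens([13]))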

I obtain the same result when training with transformers.Trainer.
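
For reference, the Trainer run was roughly the following (written from memory, so treat the exact arguments as approximate):

# Rough sketch of the equivalent Trainer setup (arguments approximate)
training_args = transformers.TrainingArguments(output_dir="./albert-wikitext2",
                                               num_train_epochs=20,
                                               per_device_train_batch_size=8,
                                               learning_rate=0.01)
trainer = transformers.Trainer(model=model,
                               args=training_args,
                               data_collator=data_collator,
                               train_dataset=data)
trainer.train()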

I am still fairly new to the Transformers API, so I am not sure what is going wrong here. Any help would be appreciated.