Failed to train Llama model

I am having trouble training a Llama model. Even though I train this small model (~1.5M parameters) on a very small dataset (only 5 samples), it is unable to overfit the data.

This is the Llama model I created:

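import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from transformers import LlamaConfig, LlamaForCausalLM

class Llama(nn.Module):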
    def __init__(self):
        super().__init__()
        config = LlamaConfig(
            vocab_size=NUM_VOCAB,
            hidden_size=EMBEDING_SIZE,
            intermediate_size=512,
            num_hidden_layers=NUM_LAYER,
            num_attention_heads=NUM_HEAD,
            hidden_act="silu",
            max_position_embeddings=MAX_SEQUENCE_LENGTH,
            initializer_range=0.02,
            rms_norm_eps=1e-6,
            use_cache=True,
            pad_token_id=0,
            bos_token_id=0,
            eos_token_id=0,
            tie_word_embeddings=False,
        )

        self.main = LlamaForCausalLM(config)

    def forward(self, inputs):
        return self.main(inputs).logits
    
model = Llama()

I tried to train this model on my own data with this function:

def train_model(model, train_loader: DataLoader):
    global min_loss
    model.to(DEVICE)
    criteria = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), LEARNING_RATE)

    for epoch in range(EPOCH):
        model.train()
        running_loss = 0.0
        for inputs, targets in train_loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            optimizer.zero_grad()
            outputs = model(inputs)
            
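            # For each target sequence, take the position of its third 0
            # (or MAX_SEQUENCE_LENGTH if it has fewer than three zeros),
            # then use the largest such position in the batch as the cutoff.
            important_index = torch.tensor([([i for i, val in enumerate(seq) if val == 0]+[MAX_SEQUENCE_LENGTH]*3)[2] for seq in targets]).max().item()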

            loss = criteria(outputs[:,:important_index,:].permute(0,2,1), targets[:,:important_index])

            loss.backward()
            running_loss += loss.item()
            optimizer.step()

        epoch_loss = running_loss / len(train_loader)
        if min_loss > epoch_loss:
            torch.save(model.state_dict(), MODEL_PATH)
            min_loss = epoch_loss
            with open(LOG_FILE, 'w') as wf:
                wf.write(str(min_loss))
        
        print(f"Epoch {epoch+1}/{EPOCH} - Loss: {epoch_loss}")

This is an example of input tensors:

tensor([[    0,    88,   159,    52,   861, 20093,  4063, 14365, 20868, 21621,
          2731, 21621, 10405, 12979,   130,     0],
        [    0,   171,   155,    52,   104, 20977,    76,   104,   209,    52,
           861,    10, 15781,   130, 11618,     0]]) 

and ground truth tensors:

tensor([[  0, 876,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0],
        [  0, 875,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,
           0,   0]])
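
To make the slicing in the training loop concrete, here is a minimal pure-Python sketch of the important_index computation applied to the ground truth tensors above. MAX_SEQUENCE_LENGTH = 16 is an assumption based on the 16-token examples, and third_zero_index is a hypothetical helper name that is not part of the training code.

```python
MAX_SEQUENCE_LENGTH = 16  # assumption: matches the 16-token examples shown


def third_zero_index(seq, fallback=MAX_SEQUENCE_LENGTH):
    """Position of the third 0 in seq, or fallback if there are fewer than three."""
    zero_positions = [i for i, val in enumerate(seq) if val == 0]
    return (zero_positions + [fallback] * 3)[2]


# The two ground truth rows from the example batch.
targets = [
    [0, 876] + [0] * 14,
    [0, 875] + [0] * 14,
]

# Same reduction as in the training loop: batch-wise maximum.
important_index = max(third_zero_index(seq) for seq in targets)

# The slice of each target row that actually enters the loss.
kept = [seq[:important_index] for seq in targets]

print(important_index)  # 3
print(kept)             # [[0, 876, 0], [0, 875, 0]]
```

For this batch, only target positions 0, 1 and 2 contribute to the loss, and the only non-zero label sits at position 1.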

After a few training runs, I think the model looks only at the start of each sample and ignores the rest. For example, in this tensor

[    0,    88,   159,    52,   861, 20093,  4063, 14365, 20868, 21621, 2731, 21621, 10405, 12979, 130,     0]

it seems the model uses only the first two numbers of the tensor (0 and 88) to compute its prediction, which makes it unable to tell apart two sentences that start the same way.
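
One possible explanation worth checking: LlamaForCausalLM applies a causal attention mask, so the logits at position t depend only on tokens 0..t. Since the label sits at target position 1, the prediction there would only ever see the first two input tokens. The sketch below is a toy stand-in (a running mean, not Llama's actual attention) that illustrates this property; causal_prefix_mean and seq_b are hypothetical and chosen only for illustration.

```python
def causal_prefix_mean(tokens):
    """Toy causal model: the output at position t uses only tokens[0..t]."""
    outputs = []
    running_sum = 0
    for t, tok in enumerate(tokens):
        running_sum += tok
        outputs.append(running_sum / (t + 1))
    return outputs


seq_a = [0, 88, 159, 52, 861]  # start of the first example input
seq_b = [0, 88, 999, 7, 3]     # hypothetical input sharing the first two tokens

out_a = causal_prefix_mean(seq_a)
out_b = causal_prefix_mean(seq_b)

# Position 1 cannot distinguish the two sequences; the last position can.
same_at_1 = out_a[1] == out_b[1]
differ_at_last = out_a[-1] != out_b[-1]
print(same_at_1, differ_at_last)  # True True
```

If this is what happens in my setup, placing the label at the last non-padding position instead of position 1 might let the prediction use the whole sentence. This is something to test, not a confirmed fix.
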
Is there anything I am missing about Llama (its architecture or its internal code), and how can I get past this problem?
Thanks in advance.

Also, the inputs to the model have shape (batch, sequence ids), and the targets have the same shape (batch, sequence ids). The outputs of the model, as far as I understand, have shape (batch, sequence ids, feature), where feature is the vocabulary dimension. This may be relevant because I hit the same problem when training decoder-only models with the x-transformers library.
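
To spell out what the loss sees with these shapes: nn.CrossEntropyLoss expects the class dimension second, which is why the training loop permutes the logits from (batch, sequence, vocab) to (batch, vocab, sequence). The plain-Python sketch below shows the equivalent computation with the default mean reduction over every (batch, position) pair. The cross_entropy function and the tiny 4-word vocabulary are illustrative assumptions, not my actual setup.

```python
import math


def cross_entropy(logits, targets):
    """Mean cross-entropy over all positions.

    logits:  nested list indexed as [batch][position][vocab]
    targets: nested list indexed as [batch][position], holding class ids
    """
    total, count = 0.0, 0
    for seq_logits, seq_targets in zip(logits, targets):
        for position_logits, target in zip(seq_logits, seq_targets):
            # Negative log-softmax of the target class at this position.
            log_norm = math.log(sum(math.exp(x) for x in position_logits))
            total += log_norm - position_logits[target]
            count += 1
    return total / count


# Batch of 2 sequences, 3 positions each, vocabulary of 4 (all hypothetical).
logits = [[[0.0] * 4 for _ in range(3)] for _ in range(2)]
targets = [[0, 1, 2], [3, 0, 1]]

# Uniform logits give a loss of log(vocab size) at every position.
loss = cross_entropy(logits, targets)
```

Every kept position is weighted equally here, so positions whose target is 0 count just as much as the single position that carries the label.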