Need help: decoder-only transformer attention attends to future tokens

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class DecoderSequential(nn.Sequential):
    """nn.Sequential variant that forwards extra keyword arguments
    (e.g. attention_mask) to every decoder block."""

    @staticmethod
    def is_key_duplicate(dictionary, key_to_check):
        # Currently unused helper.
        count = 0
        for key in dictionary:
            if key == key_to_check:
                count += 1
                if count > 1:
                    return True
        return False

    def forward(self, hidden_states, **kwargs):
        # The embedding lookup happens in Model.forward, so this container only
        # threads the hidden states through each block in order, passing the
        # keyword arguments (such as attention_mask) along unchanged.
        for module in self:
            hidden_states = module(hidden_states, **kwargs)
        return hidden_states

class Attention(nn.Module):
    def __init__(self, n_head, d_model):
        super().__init__()
        assert d_model % n_head == 0
        self.n_head = n_head
        self.head_dim = d_model // n_head
        self.d_model = d_model
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.wo = nn.Linear(d_model, d_model)
    def forward(self, q, k, v, attention_mask=None):
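        # q, k, v: (bsz, seq_len, d_model). When no attention_mask is supplied,
        # a causal mask is built below and added directly to the attention scores.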
        bsz, q_len, _ = q.size()
        if attention_mask is None:
            attention_mask = self.generate_causal_mask(q)
        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)
        q = q.view(bsz, q_len, self.n_head, self.head_dim).transpose(1,2)
        k = k.view(bsz, q_len, self.n_head, self.head_dim).transpose(1,2)
        v = v.view(bsz, q_len, self.n_head, self.head_dim).transpose(1,2)
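        # q, k, v are now shaped (bsz, n_head, seq_len, head_dim)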
        attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
        attn = attn + attention_mask
        attn = F.softmax(attn, dim=-1)
        attn = torch.matmul(attn, v).transpose(1, 2).contiguous()
        attn = attn.reshape(bsz, q_len, self.d_model)
        return self.wo(attn)
    def generate_causal_mask(self, x):
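        # Upper-triangular (bsz, 1, seq_len, seq_len) matrix with ones strictly
        # above the diagonal, i.e. ones at the future positions; the singleton
        # dimension broadcasts across heads.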
        mask = torch.triu(torch.ones(x.size(0), 1, x.size(1), x.size(1)), diagonal=1).to(x.device)
        return mask

class CMLP(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.fc1 = nn.Conv1d(d_model, d_model * 4, kernel_size=1)
        self.fc2 = nn.Conv1d(d_model * 4, d_model, kernel_size=1)
        self.act = nn.SiLU()
    def forward(self, x):
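        # x: (bsz, seq_len, d_model). kernel_size=1 convolutions act as
        # position-wise linear layers with a 4x expansion, so transpose to
        # channels-first for Conv1d and back afterwards.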
        return self.fc2(self.act(self.fc1(x.transpose(1, 2)))).transpose(1, 2)

class Block(nn.Module):
    def __init__(self, d_model, n_head):
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.attention = Attention(n_head, d_model)
        self.cmlp = CMLP(d_model)
    def forward(self, input_ids, attention_mask=None):
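        # Self-attention with a residual connection, followed by the
        # convolutional MLP with a second residual connection.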
        hidden_states = input_ids
        residual = hidden_states
        hidden_states = self.attention(
            hidden_states,
            hidden_states,
            hidden_states,
            attention_mask=attention_mask
        )
        hidden_states = hidden_states + residual
        residual = hidden_states
        hidden_states = self.cmlp(hidden_states)
        hidden_states = hidden_states + residual
        return hidden_states

class Model(nn.Module):
    def __init__(self, vocab_size, d_model, n_layer, n_head):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.decoder_layers = DecoderSequential()
        for i in range(n_layer):
            decoder = Block(d_model, n_head)
            self.decoder_layers.add_module(f"decoder_layer_{i}", decoder)
        self.lm_head = nn.Linear(d_model, vocab_size)
    def forward(self, input_ids, attention_mask=None, labels=None):
        hidden_states = self.embedding(input_ids)
        hidden_states = self.decoder_layers(hidden_states, attention_mask=attention_mask)
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits.view(-1, logits.size(-1)), labels.view(-1))
        return loss, logits

The attached code implements a simple decoder-only transformer that will be trained on a text-completion task. However, I’ve noticed that the attention mechanism attends to future tokens. I have made sure the attention mask is being passed through correctly, and I have a feeling the issue may not lie in the attention mechanism itself. Can someone help me track down and resolve this?
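
For anyone who wants to reproduce this, a small causality check along these lines (arbitrary sizes, untrained weights) should make the problem visible: if the masking were truly causal, changing only the final input token should leave the logits at every earlier position unchanged, so the printed difference should be zero (up to numerical noise).

torch.manual_seed(0)
model = Model(vocab_size=100, d_model=64, n_layer=2, n_head=4).eval()

ids_a = torch.randint(0, 100, (1, 16))
ids_b = ids_a.clone()
ids_b[0, -1] = (ids_b[0, -1] + 1) % 100   # change only the final token

with torch.no_grad():
    _, logits_a = model(ids_a)
    _, logits_b = model(ids_b)

# Positions 0..14 should be identical under proper causal masking.
print(torch.max((logits_a[:, :-1] - logits_b[:, :-1]).abs()))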