import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class DecoderSequential(nn.Sequential):
    @staticmethod
    def is_key_duplicate(dictionary, key_to_check):
        # Helper (currently unused): check whether a key occurs more than once.
        count = 0
        for key in dictionary:
            if key == key_to_check:
                count += 1
            if count > 1:
                return True
        return False

    def forward(self, input_ids, **kwargs):
        for module in self:
            # Embed the ids if this container owns an embedding;
            # otherwise fall back to using the input as-is.
            try:
                input = self.embedding(input_ids)
            except Exception:
                input = input_ids
            if "input_ids" in kwargs:
                kwargs.pop("input_ids", None)
            if isinstance(module, nn.Embedding):
                continue
            # Some modules take a keyword input_ids, others a positional input.
            try:
                output = module(input_ids=input, **kwargs)
            except TypeError:
                output = module(input, **kwargs)
            # Unwrap the hidden states from whatever the module returned.
            try:
                input = output.last_hidden_state
            except AttributeError:
                input = output[0]
        return output

class Attention(nn.Module):
    def __init__(self, n_head, d_model):
        super().__init__()
        assert d_model % n_head == 0
        self.n_head = n_head
        self.head_dim = d_model // n_head
        self.d_model = d_model
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.wo = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, attention_mask=None):
        bsz, q_len, _ = q.size()
        if attention_mask is None:
            attention_mask = self.generate_causal_mask(q)
        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)
        # Split heads: (bsz, q_len, d_model) -> (bsz, n_head, q_len, head_dim).
        q = q.view(bsz, q_len, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(bsz, q_len, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(bsz, q_len, self.n_head, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention with an additive mask.
        attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(self.head_dim)
        attn = attn + attention_mask
        attn = F.softmax(attn, dim=-1)
        # Merge heads back: (bsz, n_head, q_len, head_dim) -> (bsz, q_len, d_model).
        attn = torch.matmul(attn, v).transpose(1, 2).contiguous()
        attn = attn.reshape(bsz, q_len, self.d_model)
        return self.wo(attn)

    def generate_causal_mask(self, x):
        # Mask of ones above the diagonal, broadcast over heads and added
        # to the attention scores.
        mask = torch.triu(torch.ones(x.size(0), 1, x.size(1), x.size(1)), diagonal=1).to(x.device)
        return mask

class CMLP(nn.Module):
    """Position-wise MLP implemented with 1x1 convolutions."""

    def __init__(self, d_model):
        super().__init__()
        self.fc1 = nn.Conv1d(d_model, d_model * 4, kernel_size=1)
        self.fc2 = nn.Conv1d(d_model * 4, d_model, kernel_size=1)
        self.act = nn.SiLU()

    def forward(self, x):
        # Conv1d expects (bsz, channels, seq_len), hence the transposes.
        return self.fc2(self.act(self.fc1(x.transpose(1, 2)))).transpose(1, 2)

class Block(nn.Module):
    def __init__(self, d_model, n_head):
        super().__init__()
        self.d_model = d_model
        self.n_head = n_head
        self.attention = Attention(n_head, d_model)
        self.cmlp = CMLP(d_model)

    def forward(self, input_ids, attention_mask=None):
        # Self-attention with a residual connection.
        hidden_states = input_ids
        residual = hidden_states
        hidden_states = self.attention(
            hidden_states,
            hidden_states,
            hidden_states,
            attention_mask=attention_mask,
        )
        hidden_states = hidden_states + residual
        # Position-wise MLP with a residual connection.
        residual = hidden_states
        hidden_states = self.cmlp(hidden_states)
        hidden_states = hidden_states + residual
        return hidden_states

class Model(nn.Module):
    def __init__(self, vocab_size, d_model, n_layer, n_head):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.decoder_layers = DecoderSequential()
        for i in range(n_layer):
            self.decoder_layers.add_module(f"decoder_layer_{i}", Block(d_model, n_head))
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, input_ids, attention_mask=None, labels=None):
        hidden_states = self.embedding(input_ids)
        hidden_states = self.decoder_layers(hidden_states, attention_mask=attention_mask)
        logits = self.lm_head(hidden_states)
        if labels is not None:
            loss = nn.CrossEntropyLoss()(logits.view(-1, logits.size(-1)), labels.view(-1))
            return loss, logits
        return logits

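For context, this is roughly how I drive the model during training; the sizes, data, and learning rate below are placeholders rather than my real setup:

# Placeholder training step (model size, batch, and learning rate are arbitrary).
model = Model(vocab_size=1000, d_model=128, n_layer=4, n_head=8)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

batch = torch.randint(0, 1000, (2, 16))  # (batch_size, seq_len) of token ids
# Next-token setup: predict batch[:, 1:] from batch[:, :-1].
# .contiguous() because forward() flattens the labels with .view().
loss, logits = model(batch[:, :-1], labels=batch[:, 1:].contiguous())
loss.backward()
optimizer.step()
optimizer.zero_grad()
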
The code above implements a simple decoder-only transformer that I plan to train on a text-completion task. However, I've noticed that the attention mechanism attends to future tokens: the output at each position changes when I modify tokens that come after it. I've already checked that the attention mask is being passed around correctly, and I have a feeling the issue may not be in the attention mechanism itself. Can someone help me resolve this?
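
For reference, this is the minimal check that convinced me later tokens are leaking into earlier positions; the model size and inputs are arbitrary:

torch.manual_seed(0)
model = Model(vocab_size=100, d_model=32, n_layer=2, n_head=4).eval()

a = torch.randint(0, 100, (1, 8))
b = a.clone()
b[0, -1] = (b[0, -1] + 1) % 100  # edit only the final token

with torch.no_grad():
    logits_a = model(a)
    logits_b = model(b)

# With causal masking, the logits at positions 0..6 should be unaffected
# by an edit at position 7, so this should print True -- but I get False.
print(torch.allclose(logits_a[:, :-1], logits_b[:, :-1]))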