Hi everyone,
I implemented a transformer model from scratch based on Andrej Karpathy’s video. After that, I added some improvements, such as parallelized multi-head attention and KV-Cache for efficient generation.
However, I am facing a problem: when I enable KV-Cache, the generated text is significantly worse than when KV-Cache is disabled.
I expected KV-Cache to improve speed while keeping output quality the same, but my implementation seems to degrade output quality.
I suspect there might be issues with how I handle:
- Key-value cache updates (maybe torch.roll is causing issues? See the small roll demo right after this list.)
- Positional embeddings during cached generation (am I offsetting them correctly?)
- Masking or attention scores (am I incorrectly modifying the attention weights?)
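For reference, here is a tiny, self-contained check of what torch.roll does along the time dimension of a cache-shaped tensor (just standard PyTorch behaviour, not my model code): the oldest entry wraps around to the end and is then overwritten by the new key/value.

import torch

# Toy cache with time dimension 4, filled with the "timestamps" 0..3
cache = torch.arange(4.0).view(1, 1, 4, 1)   # (batch, heads, time, head_size)
cache = torch.roll(cache, -1, 2)             # shift left along time: [1, 2, 3, 0]
cache[:, :, -1, :] = 99.0                    # overwrite the wrapped slot with the new entry
print(cache.flatten().tolist())              # [1.0, 2.0, 3.0, 99.0]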
Here is my code:
import torch
import torch.nn as nn
import torch.nn.functional as F

# device, vocab_size, and the decoder function are defined earlier in my script (not shown here)

batch_size = 64
block_size = 256
pos_emd_coe = 8
# max_iters = 5000
max_iters = 1000
eval_interval = 500
eval_iters = 200
lr = 3e-4
embed_dim = 384
n_layer = 6
dropout = 0.2
class FeedForward(nn.Module):
    def __init__(self, embed):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(embed, embed * 4),
            nn.ReLU(),
            nn.Linear(embed * 4, embed),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.ff(x)
class MultiHead(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.query = nn.Linear(embed_dim, num_heads * head_size, bias=False)
        self.key = nn.Linear(embed_dim, num_heads * head_size, bias=False)
        self.value = nn.Linear(embed_dim, num_heads * head_size, bias=False)
        self.dropout1 = nn.Dropout(dropout)
        self.proj = nn.Linear(num_heads * head_size, num_heads * head_size)
        self.dropout2 = nn.Dropout(dropout)
        self.context_size = block_size
        self.forward = self.forward_normal
        self.k_cache, self.v_cache = None, None
        self.valid_length = 0

    def set_phase(self, mode='normal'):
        if mode == 'normal':
            self.forward = self.forward_normal
        elif mode == 'gen':
            self.forward = self.forward_gen

    def set_context_size(self, context_size=None, batch_size=1):
        if context_size:
            self.context_size = context_size
        device = next(self.parameters()).device
        dim = (batch_size, self.num_heads, self.context_size, self.head_size)
        self.k_cache = torch.zeros(dim, device=device)
        self.v_cache = torch.zeros(dim, device=device)
        self.valid_length = 0
    def forward_normal(self, x):
        B, T, C = x.shape
        q, k, v = self.query(x), self.key(x), self.value(x)
        q = q.view(B, T, self.num_heads, self.head_size).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_size).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_size).transpose(1, 2)
        wei = q @ k.transpose(-1, -2) * self.head_size**-0.5
        wei = wei.masked_fill(torch.ones(T, T, device=device).tril() == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout1(wei)
        out = wei @ v
        out = out.transpose(1, 2).contiguous().view(B, T, self.num_heads * self.head_size)
        out = self.proj(out)
        out = self.dropout2(out)
        return out
    def forward_gen(self, x):
        B, T, C = x.shape
        if (self.k_cache is None) or (self.v_cache is None) or (self.k_cache.shape[0] != B):
            self.set_context_size(batch_size=B)
        self.valid_length = min(self.valid_length + 1, self.context_size)
        last_t = x[:, -1:, :]
        q, k, v = self.query(last_t), self.key(last_t), self.value(last_t)
        # (B, num_heads, 1, head_size)
        q = q.view(B, 1, self.num_heads, self.head_size).transpose(1, 2)
        k = k.view(B, 1, self.num_heads, self.head_size).transpose(1, 2)
        v = v.view(B, 1, self.num_heads, self.head_size).transpose(1, 2)
        self.k_cache = torch.roll(self.k_cache, -1, 2)
        self.v_cache = torch.roll(self.v_cache, -1, 2)
        self.k_cache[:, :, -1, :] = k.squeeze()
        self.v_cache[:, :, -1, :] = v.squeeze()
        # (B, num_heads, 1, valid_length)
        wei = q @ self.k_cache[:, :, -self.valid_length:].transpose(-1, -2) * self.head_size**-0.5
        # wei = wei.masked_fill(torch.ones(T, T).tril() == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout1(wei)
        out = wei @ self.v_cache[:, :, -self.valid_length:]
        out = out.transpose(1, 2).contiguous().view(B, 1, C)
        out = self.proj(out)
        out = self.dropout2(out)
        return out
class Block(nn.Module):
    def __init__(self, embed_dim, n_head):
        super().__init__()
        head_size = embed_dim // n_head
        self.sa_head = MultiHead(n_head, head_size)
        self.ff = FeedForward(embed_dim)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.sa_head(self.ln1(x)) + x
        x = self.ff(self.ln2(x)) + x
        return x
class MyTransformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embeding = nn.Embedding(vocab_size, embed_dim)
        self.positional_embedding = nn.Embedding(block_size * pos_emd_coe, embed_dim)
        self.sa_blocks = nn.Sequential(*[Block(embed_dim, 4) for _ in range(n_layer)])
        self.ln = nn.LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size)

    def set_context_size(self, context_size):
        [b.sa_head.set_context_size(context_size) for b in self.sa_blocks]

    def set_phase(self, mode):
        [b.sa_head.set_phase(mode) for b in self.sa_blocks]

    def forward(self, x, target=None, offset=0):
        # x and target: (B, T)
        B, T = x.shape
        # embeddings: (B, T, C)
        offset = torch.full((B,), offset, device=device) if target is None else torch.randint(0, block_size * pos_emd_coe - T, (B,), device=device)
        position_emd = self.positional_embedding(offset[:, None] + torch.arange(T, device=device))
        token_emd = self.token_embeding(x)
        embeddings = token_emd + position_emd
        embeddings = self.sa_blocks(embeddings)
        # logits: (B, T, vocab_size)
        logits = self.lm_head(self.ln(embeddings))
        if target is not None:
            B, T, C = logits.shape
            loss = F.cross_entropy(logits.view(B * T, C), target.view(B * T))
        else:
            loss = None
        return logits, loss
    def generate(self, x, max_new_tokens):
        # x: (B, T)
        # y: (B, T + max_new_tokens)
        for i in range(max_new_tokens):
            # logits: (B, C)
            logits, _ = self(x[:, -block_size:])
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            x = torch.cat((x, next_token), -1)
        return x
    def generate_kv_cache(self, x, max_new_tokens):
        # x: (B, T)
        # y: (B, T + max_new_tokens)
        self.set_context_size(block_size)
        self.set_phase('gen')
        for i in range(max_new_tokens):
            # logits: (B, C)
            # logits, _ = self(x[:, -block_size:], offset=x[:, :-block_size].shape[-1])
            logits, _ = self(x[:, -1:], offset=i)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            x = torch.cat((x, next_token), -1)
        self.set_phase('normal')
        return x
Normal Generation (works well):
model.eval()
with torch.no_grad():
    x = torch.zeros((1, 1), dtype=torch.long, device=device)
    print(decoder(model.generate(x, 2000)[0].tolist()))
    # print(decoder(model.generate(x, block_size)[0].tolist()))
KV-Cache Generation (produces worse output):
model.eval()
with torch.no_grad():
    x = torch.zeros((1, 1), dtype=torch.long, device=device)
    print(decoder(model.generate_kv_cache(x, 2000)[0].tolist()))
    # print(decoder(model.generate(x, block_size)[0].tolist()))
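To narrow it down, I was planning to compare the two paths directly on the same prompt with greedy decoding, so sampling noise doesn't hide the difference. This is just a debugging sketch using the classes above (it relies on model, block_size, and device from the code earlier):

model.eval()
with torch.no_grad():
    prompt = torch.zeros((1, 1), dtype=torch.long, device=device)

    # Uncached path: full forward over the growing sequence
    x = prompt.clone()
    for _ in range(50):
        logits, _ = model(x[:, -block_size:])
        x = torch.cat((x, logits[:, -1, :].argmax(-1, keepdim=True)), -1)

    # Cached path: one token at a time through forward_gen
    model.set_context_size(block_size)
    model.set_phase('gen')
    y = prompt.clone()
    for i in range(50):
        logits, _ = model(y[:, -1:], offset=i)
        y = torch.cat((y, logits[:, -1, :].argmax(-1, keepdim=True)), -1)
    model.set_phase('normal')

    # If the cache were equivalent, both greedy continuations should match
    print('outputs match:', torch.equal(x, y))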
Does anyone see what might be wrong in my KV-Cache implementation?
Am I handling positional embeddings and caching correctly?
Any insights would be greatly appreciated! Thanks.