Transformer KV-Cache Produces Worse Output Than Normal Generation – Why?

Hi everyone,

I implemented a transformer model from scratch based on Andrej Karpathy’s video. After that, I added some improvements, such as parallelized multi-head attention and KV-Cache for efficient generation.

However, I am facing a problem: when I enable the KV-Cache, the generated text is significantly worse than when it is disabled.
I expected the KV-Cache to improve speed while keeping output quality the same, but my implementation seems to degrade it.

I suspect there might be issues with how I handle:

  1. Key-value cache updates (maybe torch.roll is causing issues?)
  2. Positional embeddings during cached generation (am I offsetting them correctly?)
  3. Masking or attention scores (am I incorrectly modifying the attention weights?)

Here is my code:

import torch
import torch.nn as nn
import torch.nn.functional as F

# (device, vocab_size, and decoder are defined elsewhere in my notebook)

batch_size = 64
block_size = 256
pos_emd_coe = 8  # the positional-embedding table covers block_size * pos_emd_coe positions

# max_iters = 5000
max_iters = 1000
eval_interval = 500
eval_iters = 200

lr = 3e-4

embed_dim = 384
n_layer = 6
dropout = 0.2

class FeedForward(nn.Module):
    def __init__(self, embed):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(embed, embed * 4),
            nn.ReLU(),
            nn.Linear(embed * 4, embed),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.ff(x)

class MultiHead(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size

        self.query = nn.Linear(embed_dim, num_heads * head_size, bias=False)
        self.key = nn.Linear(embed_dim, num_heads * head_size, bias=False)
        self.value = nn.Linear(embed_dim, num_heads * head_size, bias=False)
        self.dropout1 = nn.Dropout(dropout)

        self.proj = nn.Linear(num_heads*head_size, num_heads*head_size)
        self.dropout2 = nn.Dropout(dropout)

        self.context_size = block_size
        self.forward = self.forward_normal
        self.k_cache, self.v_cache = None, None
        self.valid_length = 0

    def set_phase(self, mode='normal'):
        if mode == 'normal':
            self.forward = self.forward_normal
        elif mode == 'gen':
            self.forward = self.forward_gen

    def set_context_size(self, context_size=None, batch_size=1):
        if context_size:
            self.context_size = context_size
        device = next(self.parameters()).device
        dim = (batch_size, self.num_heads, self.context_size, self.head_size)
        self.k_cache = torch.zeros(dim, device=device)
        self.v_cache = torch.zeros(dim, device=device)
        self.valid_length = 0

    def forward_normal(self, x):
        B, T, C = x.shape
        q, k, v = self.query(x), self.key(x), self.value(x)
        q = q.view(B, T, self.num_heads, self.head_size).transpose(1, 2)
        k = k.view(B, T, self.num_heads, self.head_size).transpose(1, 2)
        v = v.view(B, T, self.num_heads, self.head_size).transpose(1, 2)

        wei = q @ k.transpose(-1, -2) * self.head_size**-0.5

        wei = wei.masked_fill(torch.ones(T, T, device=device).tril() == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout1(wei)

        out = wei @ v
        out = out.transpose(1, 2).contiguous().view(B, T, self.num_heads * self.head_size)

        out = self.proj(out)
        out = self.dropout2(out)
        return out

    def forward_gen(self, x):
        B, T, C = x.shape

        if (self.k_cache is None) or (self.v_cache is None) or (self.k_cache.shape[0] != B):
            self.set_context_size(batch_size=B)

        self.valid_length = min(self.valid_length + 1, self.context_size)

        last_t = x[:, -1:, :]
        q, k, v = self.query(last_t), self.key(last_t), self.value(last_t)

        # (B, C//H, 1, H)
        q = q.view(B, 1, self.num_heads, self.head_size).transpose(1, 2)
        k = k.view(B, 1, self.num_heads, self.head_size).transpose(1, 2)
        v = v.view(B, 1, self.num_heads, self.head_size).transpose(1, 2)

        self.k_cache = torch.roll(self.k_cache, -1, 2)
        self.v_cache = torch.roll(self.v_cache, -1, 2)

        self.k_cache[:, :, -1, :] = k.squeeze()
        self.v_cache[:, :, -1, :] = v.squeeze()

        # (B, C//H, 1, context_size)
        wei = q @ self.k_cache[:, :, -self.valid_length:].transpose(-1, -2) * self.head_size**-0.5

        # wei = wei.masked_fill(torch.ones(T, T).tril() == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        wei = self.dropout1(wei)

        out = wei @ self.v_cache[:, :, -self.valid_length:]
        out = out.transpose(1, 2).contiguous().view(B, 1, C)

        out = self.proj(out)
        out = self.dropout2(out)
        return out

class Block(nn.Module):
    def __init__(self, embed_dim, n_head):
        super().__init__()
        head_size = embed_dim // n_head
        self.sa_head = MultiHead(n_head, head_size)
        self.ff = FeedForward(embed_dim)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x):
        x = self.sa_head(self.ln1(x)) + x
        x = self.ff(self.ln2(x)) + x
        return x

class MyTransformer(nn.Module):

    def __init__(self):
        super().__init__()
        self.token_embeding = nn.Embedding(vocab_size, embed_dim)
        self.positional_embedding = nn.Embedding(block_size * pos_emd_coe, embed_dim)

        self.sa_blocks = nn.Sequential(*[Block(embed_dim, 4) for _ in range(n_layer)])
        self.ln = nn.LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size)

    def set_context_size(self, context_size):
        [b.sa_head.set_context_size(context_size) for b in self.sa_blocks]

    def set_phase(self, mode):
        [b.sa_head.set_phase(mode) for b in self.sa_blocks]

    def forward(self, x, target=None, offset=0):
        # x and target: (B, T)
        B, T = x.shape
        # embeddings: (B, T, C)
        offset = torch.full((B, ), offset, device=device) if target is None else torch.randint(0, block_size * pos_emd_coe - T, (B, ), device=device)
        position_emd = self.positional_embedding(offset[:, None] + torch.arange(T, device=device))
        token_emd = self.token_embeding(x)
        embeddings = token_emd + position_emd

        embeddings = self.sa_blocks(embeddings)

        # logits: (B, T, vocab_size)
        logits = self.lm_head(self.ln(embeddings))

        if target is not None:
            B, T, C = logits.shape
            loss = F.cross_entropy(logits.view(B*T, C), target.view(B*T))
        else:
            loss = None

        return logits, loss

    def generate(self, x, max_new_tokens):
        # x: (B, T)
        # y: (B, T + max_new_tokens)
        for i in range(max_new_tokens):
            # logits: (B, C)
            logits, _ = self(x[:, -block_size:])
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            x = torch.cat((x, next_token), -1)
        return x

    def generate_kv_cache(self, x, max_new_tokens):
        # x: (B, T)
        # y: (B, T + max_new_tokens)
        self.set_context_size(block_size)
        self.set_phase('gen')
        for i in range(max_new_tokens):
            # logits: (B, C)
            # logits, _ = self(x[:, -block_size:], offset=x[:, :-block_size].shape[-1])
            logits, _ = self(x[:, -1:], offset=i)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, 1)
            x = torch.cat((x, next_token), -1)
        self.set_phase('normal')
        return x

Normal Generation (works well):

model.eval()
with torch.no_grad():
    x = torch.zeros((1, 1), dtype=torch.long, device=device)
    print(decoder(model.generate(x, 2000)[0].tolist()))
    # print(decoder(model.generate(x, block_size)[0].tolist()))

KV-Cache Generation (produces worse output):

model.eval()
with torch.no_grad():
    x = torch.zeros((1, 1), dtype=torch.long, device=device)
    print(decoder(model.generate_kv_cache(x, 2000)[0].tolist()))
    # print(decoder(model.generate(x, block_size)[0].tolist()))

Does anyone see what might be wrong in my KV-Cache implementation?
Am I handling positional embeddings and caching correctly?

Any insights would be greatly appreciated! Thanks.


Hey @mirzaim,
I couldn’t access the exact reference code or helper functions you used, so I adapted your implementation with a few substitutions (swapping in the GPT2 tokenizer, for instance) and reworked the KV-Cache setup. The good news: the KV-Cache now runs smoothly, the cached and non-cached outputs match exactly (as you can see in the results), and we get a >3x speed boost, which is what you’d expect from KV-Caching.

Note: don’t worry about the output content itself; it is just what the model produces with untrained weights under your architecture. The key point is that the two outputs are identical, which shows the KV-Cache is working as intended.

Here is the corrected code for your reference; it implements KV-Caching correctly and achieves a >3x speed boost:

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import GPT2Tokenizer
import time

block_size = 256
embed_dim = 384
n_layer = 6
n_head = 4
dropout = 0.2
device = 'cuda' if torch.cuda.is_available() else 'cpu'
random_seed = 1234

torch.manual_seed(random_seed)
if device == 'cuda':
    torch.cuda.manual_seed_all(random_seed)

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
vocab_size = tokenizer.vocab_size
tokenizer.pad_token = tokenizer.eos_token

class FeedForward(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, embed_dim * 4),
            nn.ReLU(),
            nn.Linear(embed_dim * 4, embed_dim),
            nn.Dropout(dropout),
        )
    def forward(self, x):
        return self.ff(x)

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, n_head):
        super().__init__()
        self.n_head = n_head
        self.head_dim = embed_dim // n_head
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.cache_k = None
        self.cache_v = None
        self.cache_pos = 0

    def reset_cache(self):
        self.cache_k = None
        self.cache_v = None
        self.cache_pos = 0

    def forward(self, x, use_cache=False):
        B, T, C = x.shape
        qkv = self.qkv(x)
        q, k, v = torch.split(qkv, C, dim=2)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)

        if use_cache:
            if self.cache_k is None:
                self.cache_k = torch.zeros(
                    (B, self.n_head, block_size, self.head_dim),
                    device=x.device, dtype=x.dtype
                )
                self.cache_v = torch.zeros_like(self.cache_k)
                self.cache_pos = 0

            # If the new tokens would not fit, simply start over with an empty
            # cache (this simple policy is never hit in the short demo below).
            if self.cache_pos + T > block_size:
                self.cache_k.zero_()
                self.cache_v.zero_()
                self.cache_pos = 0

            # Attend over everything cached so far plus the freshly computed keys/values.
            old_k = self.cache_k[:, :, :self.cache_pos]
            old_v = self.cache_v[:, :, :self.cache_pos]
            full_k = torch.cat([old_k, k], dim=2)
            full_v = torch.cat([old_v, v], dim=2)
            total_len = self.cache_pos + T
            if T == 1:
                causal_mask = torch.ones((T, total_len), device=x.device, dtype=torch.bool)
            else:
                causal_mask = torch.tril(torch.ones((T, total_len), device=x.device, dtype=torch.bool))
            alpha = (q @ full_k.transpose(-2, -1)) * (self.head_dim**-0.5)
            alpha = alpha.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
            alpha = F.softmax(alpha, dim=-1)
            alpha = self.dropout(alpha)
            out = alpha @ full_v
            self.cache_k[:, :, self.cache_pos:self.cache_pos+T] = k
            self.cache_v[:, :, self.cache_pos:self.cache_pos+T] = v
            self.cache_pos += T
        else:
            full_k = k
            full_v = v
            causal_mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool))
            alpha = (q @ full_k.transpose(-2, -1)) * (self.head_dim**-0.5)
            alpha = alpha.masked_fill(~causal_mask.unsqueeze(0).unsqueeze(0), float('-inf'))
            alpha = F.softmax(alpha, dim=-1)
            alpha = self.dropout(alpha)
            out = alpha @ full_v
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(out)

class Block(nn.Module):
    def __init__(self, embed_dim, n_head):
        super().__init__()
        self.sa_head = MultiHeadAttention(embed_dim, n_head)
        self.ff = FeedForward(embed_dim)
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)

    def forward(self, x, use_cache=False):
        x = self.sa_head(self.ln1(x), use_cache=use_cache) + x
        x = self.ff(self.ln2(x)) + x
        return x

class Transformer(nn.Module):
    def __init__(self):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, embed_dim)
        self.pos_emb = nn.Embedding(block_size, embed_dim)
        self.keys = nn.ModuleList([Block(embed_dim, n_head) for _ in range(n_layer)])
        self.ln = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, idx, targets=None, use_cache=False, pos_start=0):
        B, T = idx.shape
        positions = (pos_start + torch.arange(T, device=device)) % block_size
        tok_emb = self.tok_emb(idx)
        pos_emb = self.pos_emb(positions)
        x = tok_emb + pos_emb
        for block in self.keys:
            x = block(x, use_cache=use_cache)
        x = self.ln(x)
        logits = self.head(x)
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
            return logits, loss
        return logits, None

    def generate(self, idx, max_new_tokens=50):
        start_time = time.time()
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            seq_len = idx.size(1)
            pos_start = max(0, seq_len - block_size)
            logits, _ = self(idx_cond, use_cache=False, pos_start=pos_start)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            idx = torch.cat([idx, next_token], dim=1)
        elapsed = time.time() - start_time
        return idx, elapsed

    def generate_with_cache(self, idx, max_new_tokens=50):
        start_time = time.time()
        for block in self.keys:
            block.sa_head.reset_cache()
        seq_len = idx.size(1)
        logits, _ = self(idx, use_cache=True, pos_start=0)
        next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        idx = torch.cat([idx, next_token], dim=1)
        seq_len += 1
        for _ in range(1, max_new_tokens):
            last_token = idx[:, -1:]
            pos_start = (seq_len - 1) % block_size
            logits, _ = self(last_token, use_cache=True, pos_start=pos_start)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            idx = torch.cat([idx, next_token], dim=1)
            seq_len += 1
            if idx.size(1) > block_size:
                idx = idx[:, -block_size:]
                seq_len = block_size
        elapsed = time.time() - start_time
        return idx, elapsed

if __name__ == "__main__":
    model = Transformer().to(device)
    model.eval()
    input_text = "The answer to life is indeed HuggingFace, dont you agree"
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(device)
    with torch.inference_mode():
        out_normal, time_normal = model.generate(input_ids.clone(), max_new_tokens=50)
        out_cache, time_cache = model.generate_with_cache(input_ids.clone(), max_new_tokens=50)
    print("Normal Generation:")
    print(tokenizer.decode(out_normal[0]))
    print("\nKV Cache Generation:")
    print(tokenizer.decode(out_cache[0]))
    print(f"\nNormal Gen Speed: {50/time_normal:.2f} tokens/sec")
    print(f"KV Cache Gen Speed: {50/time_cache:.2f} tokens/sec")
    match = torch.all(out_normal == out_cache)
    print(f"\nOutputs Match? {match.item()}")

Your suspicions about where things went wrong were pretty much on the mark. I’ve grouped the main issues into the four points below, along with how the fixed version tackles each one.

(1) Incorrect KV-Cache Updates
In your original implementation, you shifted the cache with torch.roll(self.k_cache, -1, 2) and torch.roll(self.v_cache, -1, 2) and then wrote the new keys/values into the last slot. Because the roll is circular and the valid region is only tracked indirectly through valid_length, the cache contents and that bookkeeping can easily drift out of sync, so attention ends up working over an inconsistent history and the outputs degrade.

The fixed code keeps an explicit write pointer and appends instead: full_k = torch.cat([old_k, k], dim=2) for the attention and self.cache_k[:, :, self.cache_pos:self.cache_pos+T] = k for the update, so the cached sequence always stays in order and attention sees exactly the right context.
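
For illustration, here is a minimal standalone sketch of that append-style update (the shapes and the append_to_cache helper are made up for the example, not part of the code above):

import torch

B, H, max_len, D = 1, 4, 8, 16            # batch, heads, cache capacity, head dim
k_cache = torch.zeros(B, H, max_len, D)    # preallocated key cache
cache_pos = 0                              # number of valid entries so far

def append_to_cache(k_new):
    """k_new: (B, H, T, D) keys for the T newly processed tokens."""
    global cache_pos
    T = k_new.size(2)
    k_cache[:, :, cache_pos:cache_pos + T] = k_new   # write in order, no roll
    cache_pos += T
    return k_cache[:, :, :cache_pos]                 # only the valid prefix

valid_k = append_to_cache(torch.randn(B, H, 1, D))   # (B, H, 1, D) after one token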

(2) Missing Causal Mask in Cached Generation
Your original forward_gen computed attention weights with wei = q @ self.k_cache[…] and no mask at all (the masked_fill line is commented out). That only works if every slot in the sliced cache is a genuinely valid past token; as soon as the roll/valid_length bookkeeping drifts, nothing stops the query from attending over stale or zero-initialized cache slots, which throws off generation.

The corrected code builds the mask explicitly: an all-True causal_mask = torch.ones((T, total_len)) when T == 1 (the single new token may attend to the whole cache) and causal_mask = torch.tril(...) for the multi-token prefill, so each query only ever sees valid past tokens.
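
If you ever need the fully general case (a partially filled cache plus several new tokens at once), the mask can be built as below; this is just an illustrative sketch with example values for cache_pos and T, and it reduces to the two branches above:

import torch

cache_pos, T = 5, 3                        # 5 cached tokens, 3 new ones
total_len = cache_pos + T
mask = torch.ones(T, total_len, dtype=torch.bool)
mask[:, cache_pos:] = torch.tril(torch.ones(T, T, dtype=torch.bool))
# With cache_pos == 0 this is plain tril (the prefill case);
# with T == 1 it is a single all-True row (the decode case).
print(mask)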

(3) Improper Prompt Processing
In your original implementation, generate_kv_cache fed only the last token on every step with logits, _ = self(x[:, -1:], offset=i), including the very first step. The earlier prompt tokens therefore never make it into the cache, so every prediction is made without the full prompt context.

The updated code begins with logits, _ = self(idx, use_cache=True, pos_start=0), processing the whole prompt in one prefill pass to populate the cache and taking the first new token from its logits, which matches what non-cached generation does.
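
Condensed, the prefill-then-decode pattern from generate_with_cache looks like the sketch below (it assumes the Transformer model, block_size, a prompt_ids tensor and a max_new_tokens count from above, and leaves out the cache-reset and window-trimming details):

import torch

with torch.inference_mode():
    # Prefill: run the whole prompt once so every layer's cache holds its keys/values.
    logits, _ = model(prompt_ids, use_cache=True, pos_start=0)
    out = torch.cat([prompt_ids, logits[:, -1, :].argmax(dim=-1, keepdim=True)], dim=1)

    # Decode: from now on, feed only the newest token; the cache supplies the rest.
    for _ in range(1, max_new_tokens):
        pos_start = (out.size(1) - 1) % block_size
        logits, _ = model(out[:, -1:], use_cache=True, pos_start=pos_start)
        out = torch.cat([out, logits[:, -1, :].argmax(dim=-1, keepdim=True)], dim=1)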

(4) Complex Positional Embeddings
Your original implementation used an oversized embedding table, self.positional_embedding = nn.Embedding(block_size * pos_emd_coe, embed_dim), with an offset of offset[:, None] + torch.arange(T). This made it easy for the positions used during cached generation to fall out of line with the positions the cached keys/values were computed with, confusing the model.

The fixed code simplifies it to self.pos_emb = nn.Embedding(block_size, embed_dim) and positions = (pos_start + torch.arange(T)) % block_size, aligning positions cleanly with the cache.
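
As a quick sanity check of that indexing (example numbers only):

import torch

block_size = 256

def positions_for(pos_start, T):
    # Same indexing as the corrected forward(): the next T positions, wrapped into the table.
    return (pos_start + torch.arange(T)) % block_size

print(positions_for(0, 4))   # tensor([0, 1, 2, 3])  -> a 4-token prompt
print(positions_for(4, 1))   # tensor([4])           -> the next generated token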

Hope this clarifies the issues in the original code and that the corrected version supports your use case effectively.
Hope this helps :hugs:!
