Does use_cache (with past_key_values) make GPT-2 inference slower?

Hi, I am trying to see the benefit of using use_cache in transformers. Caching keys and values makes sense to me, but I am not sure why the following code shows that use_cache=True actually leads to longer inference time than use_cache=False.
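
To make the expected benefit concrete, here is a toy sketch in NumPy. It is not GPT-2's or transformers' actual code: the single attention head, the width d=8, the random weights and the KVCache class are all assumptions for illustration. It decodes six positions two ways, once recomputing keys and values for the whole prefix at every step and once appending a single new key/value row to a cache per step. It checks that both produce the same attention output and counts how many K/V projection rows each approach computes.

```python
import numpy as np

rng = np.random.default_rng(0)
d = 8  # toy model width (assumption)
W_q, W_k, W_v = (rng.standard_normal((d, d)) for _ in range(3))


def attend(q, K, V):
    """Single-head attention of one query q (d,) over keys K (t, d) and values V (t, d)."""
    scores = K @ q / np.sqrt(d)            # (t,)
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()               # softmax over the t positions
    return weights @ V                     # (d,)


def last_output_no_cache(x):
    """No cache: project K and V for the entire prefix x (t, d) on every call."""
    K, V = x @ W_k, x @ W_v                # t rows projected each time
    return attend(x[-1] @ W_q, K, V), len(x)


class KVCache:
    """Hypothetical minimal cache: keeps K/V rows of all previous positions."""

    def __init__(self):
        self.K = np.empty((0, d))
        self.V = np.empty((0, d))

    def step(self, x_t):
        """Project only the new position x_t (d,), append it, attend over the cache."""
        self.K = np.vstack([self.K, x_t @ W_k])
        self.V = np.vstack([self.V, x_t @ W_v])
        return attend(x_t @ W_q, self.K, self.V), 1


xs = rng.standard_normal((6, d))  # stand-in for 6 token embeddings
cache = KVCache()
proj_plain = proj_cache = 0
for t in range(1, len(xs) + 1):
    out_plain, n_plain = last_output_no_cache(xs[:t])
    out_cache, n_cache = cache.step(xs[t - 1])
    proj_plain += n_plain
    proj_cache += n_cache
    assert np.allclose(out_plain, out_cache)  # same result either way

print(proj_plain, proj_cache)  # → 21 6
```

Without a cache the number of projected rows grows as 1 + 2 + ... + t, while the cache adds one row per step; that gap is where use_cache is supposed to save time.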

```python
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch
import time

device = "cuda" if torch.cuda.is_available() else "cpu"
output_lens = [50, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000]
bsz = 1
print(f"Device used: {device}")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").to(device=device)
model.eval()


def sync():
    # torch.cuda.synchronize() fails on CPU-only machines, so only call it on GPU
    if device == "cuda":
        torch.cuda.synchronize()


print("Inference for GPT-2...")
generated = tokenizer.encode("AF")
context = torch.tensor([generated]).to(device=device)
past = None

# Warm-up pass (no gradients needed)
with torch.no_grad():
    warmup = model(context)

# Loop 1: use_cache=False
times_gpt = []
sync()
t = time.time()
with torch.no_grad():
    for i in range(1, output_lens[-1] + 1):
        outputs = model(context, use_cache=False)
        # logits: (batch, seq_len, vocab); greedy pick at the last position of batch 0
        token = torch.argmax(outputs.logits[0, -1, :])
        #generated += [token.tolist()]
        # context now holds only the newest token, shape (1,)
        context = token.unsqueeze(0)
        if i in output_lens:
            sync()
            times_gpt.append(time.time() - t)  # cumulative time after i tokens

# Loop 2: use_cache=True, feeding past_key_values back in
times_gpt_cache = []
sync()
t = time.time()
with torch.no_grad():
    for i in range(1, output_lens[-1] + 1):
        outputs = model(context, past_key_values=past, use_cache=True)
        token = torch.argmax(outputs.logits[0, -1, :])
        #generated += [token.tolist()]
        past = outputs.past_key_values
        context = token.unsqueeze(0)
        if i in output_lens:
            sync()
            times_gpt_cache.append(time.time() - t)  # cumulative time after i tokens

print("Nb decoded tokens, time GPT2, time GPT2 with cache")
for (nb_tokens, time_gpt, time_gpt_cache) in zip(
    output_lens,
    times_gpt,
    times_gpt_cache
):
    print(nb_tokens, time_gpt, time_gpt_cache)
```
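
One detail in the script above likely matters for the comparison. In the use_cache=False loop, context is overwritten with only the newest token, so every forward pass sees a single token and never re-processes the earlier ones. That loop therefore does no more work per step than the cached loop, which additionally attends over a growing past_key_values, so the cached loop coming out slightly slower is not surprising. A fair no-cache baseline has to re-feed the whole, growing sequence. Below is a minimal sketch of such a comparison. It assumes a tiny, randomly initialised GPT2Config (arbitrary sizes, chosen so it runs offline without downloading weights), and the helpers greedy_no_cache, greedy_with_cache and timed are hypothetical names. For numbers comparable to the post, swap in GPT2LMHeadModel.from_pretrained("gpt2") and generate a few hundred tokens.

```python
import time

import torch
from transformers import GPT2Config, GPT2LMHeadModel

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumption: tiny random GPT-2 so the sketch runs offline (sizes are arbitrary)
config = GPT2Config(n_layer=2, n_head=2, n_embd=64, n_positions=128, vocab_size=100)
model = GPT2LMHeadModel(config).to(device).eval()


def sync():
    if device == "cuda":
        torch.cuda.synchronize()


@torch.no_grad()
def greedy_no_cache(input_ids, n_new):
    ids = input_ids
    for _ in range(n_new):
        logits = model(ids, use_cache=False).logits                # (batch, seq, vocab)
        next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)   # (batch, 1)
        ids = torch.cat([ids, next_tok], dim=-1)  # re-feed the whole, growing sequence
    return ids


@torch.no_grad()
def greedy_with_cache(input_ids, n_new):
    ids, past, step_input = input_ids, None, input_ids
    for _ in range(n_new):
        out = model(step_input, past_key_values=past, use_cache=True)
        past = out.past_key_values
        next_tok = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_tok], dim=-1)
        step_input = next_tok  # only the newest token; earlier K/V come from `past`
    return ids


def timed(fn, *args):
    sync()
    t0 = time.perf_counter()
    result = fn(*args)
    sync()
    return result, time.perf_counter() - t0


prompt = torch.tensor([[1, 2]], device=device)
n_new = 30
greedy_no_cache(prompt, 5)  # warm-up
ids_plain, t_plain = timed(greedy_no_cache, prompt, n_new)
ids_cache, t_cache = timed(greedy_with_cache, prompt, n_new)
print(f"no cache: {t_plain:.3f}s  with cache: {t_cache:.3f}s")
```

Both loops should yield the same greedy tokens; only the work per step differs. With a model this small, Python overhead dominates and the two timings may be close; the gap is expected to show up with the real 124M-parameter model and longer sequences, where the no-cache step cost grows with sequence length.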

On a single GPU (RTX 2080 Ti), I got the following timings:

| Decoded tokens | Cumulative time, use_cache=False (s) | Cumulative time, use_cache=True (s) |
|---:|---:|---:|
| 50 | 0.300 | 0.318 |
| 100 | 0.598 | 0.637 |
| 200 | 1.194 | 1.276 |
| 300 | 1.788 | 1.913 |
| 400 | 2.385 | 2.548 |
| 500 | 2.983 | 3.180 |
| 600 | 3.579 | 3.813 |
| 700 | 4.175 | 4.452 |
| 800 | 4.769 | 5.087 |
| 900 | 5.366 | 5.726 |
| 1000 | 5.961 | 6.364 |

With use_cache=True, inference is consistently about 6 to 7% slower. I am using PyTorch 2.0 and transformers 4.27.4.

Why is this happening? Thanks!

@joaogante Hi, would you be able to confirm or check this by any chance? Thanks!