Generate: using k-v cache is faster but no difference to memory usage

Hello! :wave:

I’m benchmarking Whisper inference with the .generate() method, switching between using and not using the k-v cache.

My understanding is that with the cache, inference should be faster (since we cache the k-v states instead of recomputing them at every step), but VRAM usage should be higher (since we keep the cached tensors in memory).
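To make this trade-off concrete, here is a toy sketch in pure Python. It uses a single attention head over 2-d vectors with made-up weights (`W_K`, `W_V`, and the use of the raw token as the query are all hypothetical simplifications, not Whisper's real attention). With the cache, each step projects only the newest token, so there are fewer projections (faster), but the k/v vectors of every past token stay in memory. Without the cache, the whole prefix is re-projected every step and then discarded.

```python
import math

# Hypothetical toy weights for one attention head over 2-d vectors.
W_K = [[1.0, 0.5], [0.0, 1.0]]
W_V = [[0.5, 0.0], [1.0, 1.0]]


def matvec(w, x):
    """Multiply matrix w (a list of rows) by vector x."""
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def attend(q, keys, values):
    """Scaled dot-product attention for a single query vector."""
    scale = math.sqrt(len(q))
    scores = [sum(a * b for a, b in zip(q, k)) / scale for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]  # numerically stable softmax
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(len(values[0]))]


def decode(tokens, use_cache):
    """Attend from each new token to all tokens so far.

    Returns (outputs, number of k/v projections computed,
    number of k/v vectors still held after the last step).
    """
    cache_k, cache_v = [], []
    outputs, projections = [], 0
    for t in range(1, len(tokens) + 1):
        if use_cache:
            # Project only the newest token and keep it for later steps.
            cache_k.append(matvec(W_K, tokens[t - 1]))
            cache_v.append(matvec(W_V, tokens[t - 1]))
            projections += 2
            keys, values = cache_k, cache_v
        else:
            # Re-project the whole prefix; these lists are dropped after the step.
            keys = [matvec(W_K, x) for x in tokens[:t]]
            values = [matvec(W_V, x) for x in tokens[:t]]
            projections += 2 * t
        # Toy simplification: the raw token vector is used as the query.
        outputs.append(attend(tokens[t - 1], keys, values))
    return outputs, projections, len(cache_k) + len(cache_v)


tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
cached_out, cached_proj, cached_held = decode(tokens, use_cache=True)
plain_out, plain_proj, plain_held = decode(tokens, use_cache=False)
print(cached_proj, cached_held)  # 8 8
print(plain_proj, plain_held)    # 20 0
```

Both runs produce the same attention outputs. For four tokens, the cached run computes 8 projections and holds 8 k/v vectors at the end. The uncached run computes 20 projections and holds none between steps.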

However, I’m finding that with the cache, inference is faster but VRAM stays the same :face_with_monocle:

Here are my results with/without cache for the tiny and base Whisper checkpoints:

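| Checkpoint | Inference time with cache (s) | Inference time without cache (s) | VRAM with cache (MB) | VRAM without cache (MB) |
|---|---|---|---|---|
| tiny | 9.0 | 12.0 | 1381 | 1381 |
| base | 11.3 | 18.4 | 1523 | 1523 |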
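The relative speedup from the cache can be read off the table. The short sketch below does that arithmetic. The timing values are copied from the table, and the `speedup` helper is just an illustrative name.

```python
# Total inference times (seconds) copied from the table above.
timings = {
    "tiny": {"with_cache": 9.0, "without_cache": 12.0},
    "base": {"with_cache": 11.3, "without_cache": 18.4},
}


def speedup(t):
    """Ratio of uncached to cached inference time (higher means the cache helps more)."""
    return t["without_cache"] / t["with_cache"]


for name, t in timings.items():
    print(f"{name}: {speedup(t):.2f}x faster with cache")
```

This gives roughly 1.33x for tiny and 1.63x for base. The time saved grows with model size, while the VRAM readings do not move.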

These experiments use greedy decoding, a batch size of 1 and 73 eval samples on a 16GB V100. I measure VRAM by calling nvidia-smi and reading the memory used on the GPU.

Is this as expected? Or should we see lower VRAM without cache?

Notebook: codesnippets/benchmark_whisper_cache.ipynb at main · sanchit-gandhi/codesnippets · GitHub

Code snippet to reproduce:
```python
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

import torch
from torch.utils.data import DataLoader

import time
from tqdm import tqdm
import subprocess as sp

checkpoint_id = "openai/whisper-tiny.en"
processor = WhisperProcessor.from_pretrained(checkpoint_id)

model = WhisperForConditionalGeneration.from_pretrained(checkpoint_id)
model.to("cuda")
model.half()

librispeech = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")


def preprocess(batch):
    # Compute the log-Mel input features for a single example
    batch["input_features"] = processor(
        batch["audio"]["array"], sampling_rate=16000, return_tensors="pt"
    ).input_features[0]
    return batch


dataset_processed = librispeech.map(preprocess, remove_columns=librispeech.column_names)

dataloader = DataLoader(dataset_processed.with_format("torch"), batch_size=1)


def get_gpu_memory():
    """
    Python equivalent of nvidia-smi, copied from https://stackoverflow.com/a/67722676
    and verified as being equivalent ✅

    Returns the memory used (in MiB) on each GPU, as reported by nvidia-smi
    (whole-device usage, not only this process).
    """
    command = "nvidia-smi --query-gpu=memory.used --format=csv"
    try:
        output = sp.check_output(command.split(), stderr=sp.STDOUT)
    except sp.CalledProcessError as e:
        raise RuntimeError(
            "command '{}' returned with error (code {}): {}".format(e.cmd, e.returncode, e.output)
        )
    # Drop the trailing empty line and the CSV header ("memory.used [MiB]")
    memory_use_info = output.decode("ascii").split("\n")[:-1][1:]
    return [int(x.split()[0]) for x in memory_use_info]


# benchmark generation with cache

start = time.time()
for batch in tqdm(dataloader):
    predicted_ids = model.generate(batch["input_features"].to("cuda").half(), max_new_tokens=128, use_cache=True)
torch.cuda.synchronize()  # wait for all queued GPU work before stopping the timer
runtime = time.time() - start

print("Runtime with: ", runtime)
print("VRAM with: ", get_gpu_memory()[0])

# if we don't delete and re-load the model the GPU use is lower the second time round: warm-up effects?
del model
torch.cuda.empty_cache()

# benchmark without cache

model = WhisperForConditionalGeneration.from_pretrained(checkpoint_id)
model.to("cuda")
model.half()

start = time.time()
for batch in tqdm(dataloader):
    predicted_ids = model.generate(batch["input_features"].to("cuda").half(), max_new_tokens=128, use_cache=False)
torch.cuda.synchronize()  # wait for all queued GPU work before stopping the timer
runtime = time.time() - start

print("Runtime without: ", runtime)
print("VRAM without: ", get_gpu_memory()[0])
```

Print output:

```
Runtime with:  8.990428924560547
VRAM with:  1381
Runtime without:  11.993675231933594
VRAM without:  1381
```

Thanks!


Nice write-up!

I think the decoder sequence length and the hidden size of the model might be too small to see a difference in VRAM here.

VRAM should be higher when caching the k,v states because we cache the projected k,v states of every layer. This means that our cache holds this many values per sequence:

2 * (hidden_size) * (num_layers) * (decoder_length)

In terms of VRAM, this memory is more or less always added on top of the peak memory of the computation graph.

For comparison, we don’t hold this memory when not caching. Without the cache, the k,v memory we always have before the attention QK^T computation (which is probably the bottleneck) is 2 * (hidden_size) * 1 * (decoder_length): the k,v states of the single layer whose attention is currently being computed.

=> I expect that here (num_layers), (hidden_size) and (decoder_length) are too small to make a difference.
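To put numbers on this, here is a rough sketch that plugs the Whisper tiny and base decoder dimensions into the two formulas above. It rests on several assumptions: the model runs in fp16 (2 bytes per value, matching `model.half()` in the script), batch size is 1, and decoder_length is 128 (the `max_new_tokens` used). It counts only the decoder self-attention cache the formula describes. The `(d_model, decoder_layers)` pairs come from the public Whisper configs, and the helper names are hypothetical.

```python
MIB = 1024 ** 2


def kv_cache_bytes(hidden_size, num_layers, decoder_length,
                   batch_size=1, bytes_per_value=2):
    """Self-attention k,v cache: 2 * hidden_size * num_layers * decoder_length
    values per sequence, at bytes_per_value bytes each (2 for fp16)."""
    return 2 * hidden_size * num_layers * decoder_length * batch_size * bytes_per_value


def uncached_kv_bytes(hidden_size, decoder_length, batch_size=1, bytes_per_value=2):
    """k,v states of the single layer being computed when not caching
    (the same formula with num_layers = 1)."""
    return kv_cache_bytes(hidden_size, 1, decoder_length, batch_size, bytes_per_value)


# (d_model, decoder_layers), assumed from the public Whisper tiny/base configs.
whisper_dims = {"tiny": (384, 4), "base": (512, 6)}

for name, (hidden, layers) in whisper_dims.items():
    cached = kv_cache_bytes(hidden, layers, decoder_length=128)
    uncached = uncached_kv_bytes(hidden, decoder_length=128)
    print(f"{name}: {cached / MIB:.2f} MiB cached vs {uncached / MIB:.2f} MiB uncached")
```

Under these assumptions, the tiny cache comes out at 0.75 MiB and the base cache at 1.5 MiB. Both are a tiny fraction of the ~1.4-1.5 GB that nvidia-smi reports, which fits the expectation that the difference is too small to show up.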

The easiest way to check this would be to use a bigger model and generate much longer sequences (set eos to None and generate 256 tokens).


Overall this is an interesting finding, though, as it means that the k,v cache probably doesn’t have a big effect on VRAM for ASR, at least at that model size.

@sanchit-gandhi a few extra numbers: modifying your script to run on GPT-J with FP16 on a 3090, with input_ids.shape[1]=16 and max_new_tokens=256, we get:

  1. 14071MB of GPU usage with use_cache=False
  2. 13233MB of GPU usage with use_cache=True

The difference becomes more visible with large models and large sequence lengths :mag_right:

Thank you very much for the detailed response!

It makes sense that the difference in VRAM with and without the cache is insignificant for a model with such low dimensionality.

Repeating the experiment with the large-v2 checkpoint (hidden_size=1280, num_layers=32) and generating 256 tokens yields measurable, albeit still marginal, differences in VRAM:

VRAM with: 7597
VRAM without: 7515
Diff: 82

(all values in MB)

As expected, the effect is amplified at 512 tokens and grows with decoder_length, though less than linearly in these measurements (82 MB at 256 tokens vs 120 MB at 512):

VRAM with: 7639
VRAM without: 7519
Diff: 120

ASR models tend to generate quite short sequences. For example, the average transcription in the LibriSpeech validation corpus is just ~20 tokens long. Setting the max length accordingly, we get:

VRAM with: 7515
VRAM without: 7511
Diff: 4

So pretty insignificant! My intuition is that since the VRAM difference with/without cache is proportional to decoder length, the k-v cache doesn’t have a big effect on VRAM for ASR models, even for larger checkpoints.
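As a rough cross-check of this proportionality, the sketch below evaluates the cache-size formula from earlier in the thread, 2 * (hidden_size) * (num_layers) * (decoder_length), for large-v2 (hidden_size=1280, num_layers=32) at the decoder lengths tried here. It assumes fp16 (2 bytes per value) and batch size 1, and it counts only the self-attention cache. `kv_cache_mib` is a hypothetical helper. By construction, the estimate scales exactly linearly with decoder length. The measured differences (4, 82 and 120 MB) are of a similar order but do not match exactly. That is not surprising, since nvidia-smi reports total memory used on the device rather than the size of individual tensors.

```python
MIB = 1024 ** 2


def kv_cache_mib(hidden_size, num_layers, decoder_length, bytes_per_value=2):
    """Estimated self-attention k,v cache size in MiB for one sequence."""
    values = 2 * hidden_size * num_layers * decoder_length  # k and v, every layer
    return values * bytes_per_value / MIB


# large-v2 decoder dimensions as quoted above
for length in (20, 256, 512):
    print(f"{length} tokens: ~{kv_cache_mib(1280, 32, length):.1f} MiB of cache")
```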