Hello!
I’m benchmarking inference performance for Whisper’s .generate() method, switching between using and not using the k-v cache.
My understanding is that with the cache, inference should be faster (since we cache the k-v states instead of recomputing them at every decoding step), but VRAM usage should be higher (since we keep the cached tensors in memory).
However, I’m finding that with the cache inference is indeed faster, but VRAM usage stays the same.
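For context on what "keeping the cached tensors in memory" should cost, here's a rough back-of-the-envelope estimate read off the model config (my own sketch, not part of the benchmark; it assumes fp16 weights/activations and one K/V pair per decoder layer for both self- and cross-attention):

```python
# Rough sketch (not part of the benchmark): estimate the size of the cached k-v tensors
# from the model config, assuming fp16 (2 bytes per element), one K and one V tensor per
# decoder layer for self-attention, plus a fixed-size cross-attention cache over the
# max_source_positions encoder frames.
from transformers import WhisperConfig

config = WhisperConfig.from_pretrained("openai/whisper-tiny.en")
bytes_per_element = 2  # fp16

# self-attention cache grows by this many bytes for every generated token
self_attn_per_token = config.decoder_layers * 2 * config.d_model * bytes_per_element

# cross-attention cache is computed once over the encoder outputs and then re-used
cross_attn_total = config.decoder_layers * 2 * config.max_source_positions * config.d_model * bytes_per_element

print(f"self-attn cache: {self_attn_per_token / 1024:.1f} KiB per generated token")
print(f"cross-attn cache: {cross_attn_total / 2**20:.2f} MiB in total")
```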
Here are my results with/without cache for the tiny and base Whisper checkpoints:
| Checkpoint | Inf time with cache (s) | Inf time without cache (s) | VRAM with cache (MiB) | VRAM without cache (MiB) |
|---|---|---|---|---|
| tiny | 9.0 | 12.0 | 1381 | 1381 |
| base | 11.3 | 18.4 | 1523 | 1523 |
These experiments are run with greedy decoding, a batch size of 1, and 73 eval samples on a 16 GB V100. I’m measuring VRAM by calling nvidia-smi and reading off the reported GPU memory usage.
Is this as expected? Or should we see lower VRAM without cache?
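In case this is a measurement artefact, a minimal sketch of a cross-check I could run with PyTorch's own allocator statistics (my addition, not in the notebook; it assumes `model` and a `batch` from the dataloader are set up as in the snippet below):

```python
# Sketch of an alternative measurement: track the peak memory PyTorch itself allocates
# during generation, rather than the per-process figure nvidia-smi reports (which also
# covers the CUDA context and the caching allocator's pool).
import torch

torch.cuda.reset_peak_memory_stats()
predicted_ids = model.generate(batch["input_features"].to("cuda").half(), max_new_tokens=128, use_cache=True)
print(f"peak allocated: {torch.cuda.max_memory_allocated() / 2**20:.1f} MiB")
print(f"peak reserved:  {torch.cuda.max_memory_reserved() / 2**20:.1f} MiB")
```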
Notebook: codesnippets/benchmark_whisper_cache.ipynb at main · sanchit-gandhi/codesnippets · GitHub
Code snippet to reproduce:
```python
from datasets import load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor
import torch
from torch.utils.data import DataLoader
import time
from tqdm import tqdm
import subprocess as sp

checkpoint_id = "openai/whisper-tiny.en"

processor = WhisperProcessor.from_pretrained(checkpoint_id)
model = WhisperForConditionalGeneration.from_pretrained(checkpoint_id)
model.to("cuda")
model.half()

# 73 validation samples
librispeech = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

def preprocess(batch):
    # compute the log-mel input features for each audio sample
    batch["input_features"] = processor(batch["audio"]["array"], sampling_rate=16000, return_tensors="pt").input_features[0]
    return batch

dataset_processed = librispeech.map(preprocess, remove_columns=librispeech.column_names)
dataloader = DataLoader(dataset_processed.with_format("torch"), batch_size=1)

def get_gpu_memory():
    """
    Python equivalent of nvidia-smi, copied from https://stackoverflow.com/a/67722676
    and verified as being equivalent ✅
    """
    output_to_list = lambda x: x.decode("ascii").split("\n")[:-1]
    COMMAND = "nvidia-smi --query-gpu=memory.used --format=csv"
    try:
        memory_use_info = output_to_list(sp.check_output(COMMAND.split(), stderr=sp.STDOUT))[1:]
    except sp.CalledProcessError as e:
        raise RuntimeError("command '{}' returned with error (code {}): {}".format(e.cmd, e.returncode, e.output))
    memory_use_values = [int(x.split()[0]) for x in memory_use_info]
    return memory_use_values

# benchmark generation with cache
start = time.time()
for batch in tqdm(dataloader):
    predicted_ids = model.generate(batch["input_features"].to("cuda").half(), max_new_tokens=128, use_cache=True)
runtime = time.time() - start
print("Runtime with: ", runtime)
print("VRAM with: ", get_gpu_memory()[0])

# if we don't delete and re-load the model, GPU usage is lower the second time round: warm-up effects?
del model
torch.cuda.empty_cache()

# benchmark generation without cache
model = WhisperForConditionalGeneration.from_pretrained(checkpoint_id)
model.to("cuda")
model.half()

start = time.time()
for batch in tqdm(dataloader):
    predicted_ids = model.generate(batch["input_features"].to("cuda").half(), max_new_tokens=128, use_cache=False)
runtime = time.time() - start
print("Runtime without: ", runtime)
print("VRAM without: ", get_gpu_memory()[0])
```
Print output:
```
Runtime with: 8.990428924560547
VRAM with: 1381
Runtime without: 11.993675231933594
VRAM without: 1381
```
Thanks!