How to compile the generate method with PT 2.0?

We expect that compiling the generate method with torch.compile should make it 30-100% faster (according to the official PyTorch 2.0 announcement):

generate_fn = torch.compile(model.generate)

However, when benchmarking with the Whisper model, generation is slower with the compiled generate method than with the un-compiled one. Results for a 16 GB T4 are here: codesnippets/benchmark_whisper_generate_torch_compile.ipynb at main · sanchit-gandhi/codesnippets · GitHub

Note that, for a fair comparison, we do not include the compilation time in the benchmark (the compiled function is warmed up once before the timed loop).
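(The notebook just times the loops below; for completeness, the per-call latency could also be measured with CUDA events so that GPU work is fully synchronised. The time_generate helper here is a hypothetical sketch, not part of the notebook:)

import torch

def time_generate(generate_fn, input_features, n_iters, **generate_kwargs):
    # hypothetical helper: average latency per generate call in ms,
    # with explicit CUDA synchronisation around the timed region
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(n_iters):
        generate_fn(input_features, **generate_kwargs)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / n_iters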

The profiler suggests that the past_key_values are causing size mismatches (and hence recompilations?). Are we missing something in how we set up / compile the generate method?

Code to reproduce:
import torch
from torch._dynamo.utils import CompileProfiler

from transformers import WhisperForConditionalGeneration

from tqdm import tqdm


model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model.to("cuda").half()

BATCH_SIZE = 1
SEQ_LEN = 25
NUM_BATCHES = 100

# random input features - doesn't matter what they are, since we'll generate to a fixed length
input_features = torch.randn((BATCH_SIZE, 80, 3000))
input_features = input_features.to("cuda").half()

# benchmark vanilla generate
for i in tqdm(range(NUM_BATCHES)):
    pred_ids = model.generate(input_features, max_new_tokens=SEQ_LEN, min_new_tokens=SEQ_LEN)

profiler = CompileProfiler()
generate_fn = torch.compile(model.generate, backend=profiler)

# compilation step
pred_ids = generate_fn(input_features, max_new_tokens=SEQ_LEN, min_new_tokens=SEQ_LEN)

# benchmark compiled generate
for i in tqdm(range(NUM_BATCHES)):
    pred_ids = generate_fn(input_features, max_new_tokens=SEQ_LEN, min_new_tokens=SEQ_LEN)

# profiler report
print(profiler.report())
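
One variant we could try, in case the recompilations come from the growing past_key_values / decoder sequence length: compiling with dynamic shapes enabled. This is just a sketch under the assumption that PT 2.0's dynamic=True flag covers those dimensions; we haven't confirmed it actually helps:

# sketch: enable dynamic shape tracing so the changing past_key_values /
# decoder lengths hopefully don't trigger shape re-specialisation
generate_fn_dynamic = torch.compile(model.generate, dynamic=True)

# warm-up / compilation step, then benchmark as above
pred_ids = generate_fn_dynamic(input_features, max_new_tokens=SEQ_LEN, min_new_tokens=SEQ_LEN)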