How to compile the generate method with PT 2.0?

Based on the official PyTorch 2.0 announcement, we expected that wrapping the generate method in torch.compile would make it 30-100% faster:

generate_fn = torch.compile(model.generate)

However, when benchmarking with the Whisper model, generation is slower with the compiled generate method than with the uncompiled one. Results for a 16GB T4 are in the notebook benchmark_whisper_generate_torch_compile.ipynb in the sanchit-gandhi/codesnippets repository on GitHub.

Note that, for a fair comparison, we exclude the compilation time from the benchmark.
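To make the "compilation excluded" part concrete, here is a minimal sketch of a timing helper that runs untimed warm-up calls first and only then starts the clock. The `benchmark` function and its `sync` parameter are hypothetical names introduced for illustration; they are not part of the linked notebook.

```python
import time
from typing import Callable, Optional


def benchmark(
    fn: Callable[[], object],
    warmup: int = 1,
    iters: int = 10,
    sync: Optional[Callable[[], None]] = None,
) -> float:
    """Return the mean wall-clock seconds per call, excluding warm-up calls."""
    # Warm-up calls are not timed; for a compiled function this is where
    # the compilation cost is paid.
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()  # make sure warm-up work has finished before timing starts

    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if sync is not None:
        sync()  # wait for queued asynchronous work before stopping the clock
    return (time.perf_counter() - start) / iters
```

On a GPU, passing `sync=torch.cuda.synchronize` should keep asynchronous kernel launches from making the timed loop look faster than it is. Note that a single warm-up call only covers compilation for the shapes seen in that call.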

The profiler suggests that the past_key_values are causing size mismatches (and therefore recompilations?). Are we missing something in how we set up or compile the generate method?
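One way to picture the suspected problem is a toy model in which a compiler specializes on the exact shape of its inputs and compiles again for every shape it has not seen. This is only an illustration under that assumption, not how TorchDynamo's guards are actually implemented. `ShapeGuardedCache` and `simulate_decoding` are hypothetical names, and the head count and head dimension are illustrative values.

```python
from typing import Callable, Dict, Tuple


class ShapeGuardedCache:
    """Toy stand-in for a compiler that specializes on input shapes."""

    def __init__(self) -> None:
        self.compiled: Dict[Tuple[int, ...], Callable[[], None]] = {}
        self.compile_count = 0

    def call(self, shape: Tuple[int, ...]) -> None:
        # An unseen shape fails every existing "guard" and forces a compile.
        if shape not in self.compiled:
            self.compile_count += 1
            self.compiled[shape] = lambda: None
        self.compiled[shape]()


def simulate_decoding(
    cache: ShapeGuardedCache,
    num_new_tokens: int,
    batch: int = 1,
    heads: int = 6,
    head_dim: int = 64,
) -> None:
    """Call the cache once per decoding step with a growing key/value length."""
    for step in range(num_new_tokens):
        past_len = step  # the cached keys/values grow by one position per step
        cache.call((batch, heads, past_len, head_dim))


cache = ShapeGuardedCache()
simulate_decoding(cache, num_new_tokens=25)
first_pass = cache.compile_count   # one compile per distinct past length

simulate_decoding(cache, num_new_tokens=25)
second_pass = cache.compile_count  # no new shapes, so no new compiles
print(first_pass, second_pass)
```

In this toy model, generating 25 tokens produces 25 distinct key/value shapes, so a single warm-up generate call pays for all of them, and later calls with the same length reuse the cached entries. Whether the real compiler behaves this way depends on its cache limits and dynamic-shape settings, which the sketch does not capture.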

Code to reproduce:
import torch
from torch._dynamo.utils import CompileProfiler

from transformers import WhisperForConditionalGeneration

from tqdm import tqdm

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
model.to("cuda").half()

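# BATCH_SIZE and NUM_BATCHES are used below but were not defined in the post;
# the values here are assumptions (NUM_BATCHES matches the 100 iterations of the compiled benchmark)
BATCH_SIZE = 16
NUM_BATCHES = 100
SEQ_LEN = 25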

# random input features - their values don't matter since we generate to a fixed length
input_features = torch.randn((BATCH_SIZE, 80, 3000))
input_features = input_features.to("cuda").half()

# benchmark vanilla generate
for i in tqdm(range(NUM_BATCHES)):
    pred_ids = model.generate(input_features, max_new_tokens=SEQ_LEN, min_new_tokens=SEQ_LEN)

profiler = CompileProfiler()
generate_fn = torch.compile(model.generate, backend=profiler)

# compilation step
pred_ids = generate_fn(input_features, max_new_tokens=SEQ_LEN, min_new_tokens=SEQ_LEN)

# benchmark compiled generate
for i in tqdm(range(100)):
    pred_ids = generate_fn(input_features, max_new_tokens=SEQ_LEN, min_new_tokens=SEQ_LEN)

# profiler report
print(profiler.report())