Whisper decoder is slow for ASR task

I have followed this blog to finetune the ASR model.
The training is working fine. However, the decoding time is very slow.

Are there hyperparameters to be optimized for speeding up the decoder of Whisper?
Or is there a possibility to customize the decoder of Whisper?


Hey @ksoky!

Seq2Seq models generate text autoregressively with the decoder (see Fine-Tune Whisper For Multilingual ASR with 🤗 Transformers for details). So we perform one forward pass of the decoder for every token generated.
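To make that cost concrete, here is a minimal toy sketch (not Whisper's actual implementation) of autoregressive greedy decoding: every new token requires one more decoder forward pass, so generation time grows linearly with output length. `toy_decoder_step` is a hypothetical stand-in for the real decoder.

```python
def toy_decoder_step(tokens, vocab_size=10):
    # Hypothetical stand-in for one decoder forward pass:
    # returns a deterministic pseudo-logit for each vocabulary entry.
    return [((sum(tokens) + 1) * (i + 3)) % vocab_size for i in range(vocab_size)]

def greedy_decode(max_new_tokens, bos_id=1, eos_id=0):
    tokens = [bos_id]
    forward_passes = 0
    for _ in range(max_new_tokens):
        logits = toy_decoder_step(tokens)          # one decoder pass...
        forward_passes += 1                        # ...per generated token
        next_id = max(range(len(logits)), key=logits.__getitem__)  # argmax
        tokens.append(next_id)
        if next_id == eos_id:                      # stop at end-of-sequence
            break
    return tokens, forward_passes
```

Generation stops either at the end-of-sequence token or at `max_new_tokens`, whichever comes first; in either case the number of decoder passes equals the number of tokens produced.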

Running generation in “greedy” mode will be much faster than beam search (greedy is the default).
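The reason greedy is faster can be seen in a toy beam-search sketch (again not the 🤗 Transformers implementation, and omitting details like length penalties and per-beam EOS handling): every step runs the decoder once per live beam, so beam search costs roughly `num_beams` times more decoder passes than greedy decoding.

```python
import math

def toy_decoder_step(tokens, vocab_size=5):
    # Hypothetical stand-in for one decoder forward pass:
    # deterministic pseudo log-probabilities over a toy vocabulary.
    scores = [((sum(tokens) + 1) * (i + 2)) % 7 + 1 for i in range(vocab_size)]
    total = sum(scores)
    return [math.log(s / total) for s in scores]

def beam_search(steps, num_beams=3, bos_id=1):
    beams = [([bos_id], 0.0)]  # (token sequence, cumulative log-prob)
    forward_passes = 0
    for _ in range(steps):
        candidates = []
        for seq, score in beams:
            log_probs = toy_decoder_step(seq)   # one decoder pass per live beam
            forward_passes += 1
            for tok, lp in enumerate(log_probs):
                candidates.append((seq + [tok], score + lp))
        # keep only the num_beams best continuations
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:num_beams]
    return beams[0][0], forward_passes
```

In this toy, 10 steps with `num_beams=3` perform 28 decoder passes (1 for the initial beam, then 3 per step), versus 10 passes with `num_beams=1`, i.e. greedy.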

You can also explore reducing the “max_length”:

model.config.max_length = 100

This caps generation at 100 tokens. But it will almost certainly reduce your overall performance, as some sentences will be truncated short.

Are you running inference on GPU? It shouldn’t be too slow with the “small” checkpoint on most GPU devices!

Alternatively, you can try training one of the smaller checkpoints (“base” or “tiny”) for faster inference.

Dear @sanchit-gandhi,

Thanks for your suggestions.
I will try again and come back soon.

Best,

How did it turn out? I ran into the same issue: my fine-tuned model feels like it takes 3-6x longer to predict than whisper medium.en did. I need rapid, near-live transcriptions.
This is my code so far:

from transformers import (
    WhisperFeatureExtractor,
    WhisperForConditionalGeneration,
    WhisperTokenizer,
    pipeline,
)

path_to_model = 'path_to_model'

model = WhisperForConditionalGeneration.from_pretrained(path_to_model)
model.config.max_length = 150  # cap generation length

tokenizer = WhisperTokenizer.from_pretrained(path_to_model)
feature_extractor = WhisperFeatureExtractor.from_pretrained(path_to_model)

pipe = pipeline(
    task='automatic-speech-recognition',
    model=model,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    chunk_length_s=15,
    generate_kwargs={"num_beams": 1},  # greedy decoding (num_beams is a
    # generate-time argument, not a from_pretrained argument)
    # device=0,  # uncomment to run inference on the first GPU
)

def transcribe(audio):
    return pipe(audio)["text"]

file_path = 'path_to_audio'  # placeholder: path to an audio file
transcription = transcribe(file_path)