Adding prompt / context to Whisper with Huggingface Transformers

For anyone who is still looking for a solution: I think I managed to make it work. Here is an example with Whisper medium in French.
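
The snippets below assume whisper_processor, whisper_model, device and torch_dtype are already defined. A minimal setup sketch, assuming openai/whisper-medium:

import torch
from transformers import WhisperForConditionalGeneration, WhisperProcessor

device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
whisper_model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-medium", torch_dtype=torch_dtype
).to(device)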

First, extract the prompt ids with the processor's get_prompt_ids() method:

import torch

# Text of the preceding segment, used as context for the next one
prev_text = "est ce qu'ils sont choqués ?"
prompt_ids = torch.tensor(whisper_processor.get_prompt_ids(prev_text), device=device)
# Sanity check: the prompt is prefixed with the <|startofprev|> token
print(whisper_processor.decode(prompt_ids, skip_special_tokens=False))

“<|startofprev|> est ce qu’ils sont choqués ?”
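
The inference snippet below also relies on a helper, extract_audio_input_features, that isn't defined in this post; it just wraps the processor's feature extractor. A minimal sketch, assuming audio is a 1-D array already resampled to 16 kHz:

def extract_audio_input_features(audio, whisper_processor, device, torch_dtype):
    # Whisper expects 16 kHz audio; the processor computes log-mel features
    inputs = whisper_processor(audio, sampling_rate=16000, return_tensors="pt")
    return inputs.input_features.to(device, dtype=torch_dtype)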

Then inject this prompt at inference time. This run uses greedy decoding (no sampling) to stay deterministic, which makes the real effect of the prompt visible:

audio_features = extract_audio_input_features(
    audio,
    whisper_processor,
    device,
    torch_dtype
)
# Clear any forced decoder ids so they don't conflict with the
# language/task arguments and the prompt
whisper_model.generation_config.forced_decoder_ids = None
predicted_ids = whisper_model.generate(
    audio_features,
    language="fr",
    task="transcribe",
    eos_token_id=whisper_model.generation_config.eos_token_id,
    pad_token_id=whisper_model.generation_config.pad_token_id,
    max_time=5,
    do_sample=False,  # greedy decoding, deterministic output
    return_dict_in_generate=False,
    prompt_ids=prompt_ids  # the context extracted above
)

# skip_special_tokens drops <|startoftranscript|> etc. from the output
print(whisper_processor.decode(predicted_ids[0], skip_special_tokens=True))

“Non, ça va, pour l’instant ça va, ils sont choqués mais ça va.”

Here is the original transcription, without the prompt (again with no sampling):
“Non, ça va. Pour l’instant, ça va. Ils sont pour chocrer, mais ça va.”
(Note the error on the word “choqués”, mis-transcribed as “pour chocrer”.)
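
For reference, that baseline is presumably the same generate call with prompt_ids simply omitted (a sketch under the same assumptions as above):

predicted_ids = whisper_model.generate(
    audio_features,
    language="fr",
    task="transcribe",
    do_sample=False  # greedy decoding again, for a fair comparison
)
print(whisper_processor.decode(predicted_ids[0], skip_special_tokens=True))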

I’m open to remarks or ways to improve this.
