For anyone who is still looking for a solution.
I think I managed to make it work. Here is an example with Whisper medium in French.
First, extract the prompt ids with the get_prompt_ids() method:
import torch

# Previous transcript to condition the model on
prev_text = "est ce qu'ils sont choqués ?"
# get_prompt_ids() prepends the <|startofprev|> token and tokenizes the text
prompt_ids = torch.tensor(whisper_processor.get_prompt_ids(prev_text), device=device)
print(whisper_processor.decode(prompt_ids, skip_special_tokens=False))
"<|startofprev|> est ce qu'ils sont choqués ?"
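For reference, the snippets here assume whisper_processor, whisper_model, device and torch_dtype are already set up, and they call an extract_audio_input_features helper that I haven't shown. A minimal sketch of that setup, assuming the openai/whisper-medium checkpoint and 16 kHz mono input (the helper body is my own and only illustrative):

import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration

device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32

whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-medium")
whisper_model = WhisperForConditionalGeneration.from_pretrained(
    "openai/whisper-medium", torch_dtype=torch_dtype
).to(device)

def extract_audio_input_features(audio, processor, device, torch_dtype):
    # Illustrative helper: `audio` is assumed to be a 16 kHz mono waveform
    features = processor(audio, sampling_rate=16000, return_tensors="pt").input_features
    return features.to(device, dtype=torch_dtype)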
Then inject this prompt at inference time. Sampling is disabled here so that decoding is deterministic, which isolates the real effect of the prompt:
audio_features = extract_audio_input_features(
    audio,
    whisper_processor,
    device,
    torch_dtype,
)

# Clear any forced decoder ids so they don't conflict with the prompt
whisper_model.generation_config.forced_decoder_ids = None

predicted_ids = whisper_model.generate(
    audio_features,
    language="fr",
    task="transcribe",
    eos_token_id=whisper_model.generation_config.eos_token_id,
    pad_token_id=whisper_model.generation_config.pad_token_id,
    max_time=5,
    do_sample=False,  # greedy decoding, deterministic output
    return_dict_in_generate=False,
    prompt_ids=prompt_ids,  # conditions generation on the previous text
)
print(whisper_processor.decode(predicted_ids[0]))
"Non, ça va, pour l'instant ça va, ils sont choqués mais ça va."
For comparison, here is the transcription without the prompt (again with no sampling):
"Non, ça va. Pour l'instant, ça va. Ils sont pour chocrer, mais ça va."
(Note the error on the word "choqués", transcribed as "pour chocrer"; with the prompt, the word is recovered correctly.)
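If you want to reproduce the comparison, the baseline is just the same generate() call with prompt_ids left out; a short sketch:

# Baseline: identical call, but without prompt_ids
baseline_ids = whisper_model.generate(
    audio_features,
    language="fr",
    task="transcribe",
    do_sample=False,
)
print(whisper_processor.decode(baseline_ids[0], skip_special_tokens=True))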
I’m open to remarks or ways to improve this.