Finetuning Whisper with prompts

Hi All,
I’m trying to finetune Whisper by resuming its pre-training task and adding initial prompts as part of the model’s forward pass. I saw this amazing tutorial; however, it does not contain a section on using prompts as part of the fine-tuning dataset.

Thanks!

Any news on this?

Hi @AvivSham,

I started digging into the actual code and just realized that the Whisper tokenizer can accept two sentences as input, just as models such as BERT do. For BERT-like models, the two input sentences are concatenated and separated by a [SEP] token:

[CLS] sentence1 [SEP] sentence2 [SEP]

This behaviour is kept in the Whisper tokenizer for API consistency, although it is not actually used during fine-tuning. In the current code (the pair branch of WhisperTokenizer.build_inputs_with_special_tokens), it simply concatenates both sentences if both are passed.
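To make the current pair behaviour concrete, here is a small pure-Python sketch that mirrors the two branches described above. It uses token strings instead of real IDs so the layout is readable; the names PREFIX, EOS and current_pair_behaviour are illustrative only, and the real method works on integer IDs.

```python
from typing import List, Optional

# Illustrative token strings standing in for the real special-token IDs.
PREFIX = ["<|startoftranscript|>", "<|en|>", "<|transcribe|>", "<|notimestamps|>"]
EOS = "<|endoftext|>"


def current_pair_behaviour(tokens_0: List[str], tokens_1: Optional[List[str]] = None) -> List[str]:
    """Mimics the existing logic: a second sequence is simply appended after the first."""
    if tokens_1 is None:
        return PREFIX + tokens_0 + [EOS]
    return PREFIX + tokens_0 + tokens_1 + [EOS]


transcription = ["and", "the", "actual", "transcription"]
prompt = ["This", "is", "the", "prompt"]
print(current_pair_behaviour(transcription, prompt))
# The prompt ends up after the transcription, not before <|startoftranscript|>.
```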

To avoid changing too much code, I would simply replace the line that concatenates both sentences with the following, so that the output matches the prompt format described in the original paper:

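# Pair branch of build_inputs_with_special_tokens (token_ids_1 = prompt); look up the ID by name rather than by list position
start_of_prev_id = self.convert_tokens_to_ids("<|startofprev|>")
return [start_of_prev_id] + token_ids_1 + self.prefix_tokens + token_ids_0 + [self.eos_token_id]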
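The layout produced by that replacement line can be checked with the same kind of pure-Python sketch, again on token strings rather than IDs. As in the tokenizer call below, the first argument is the transcription and the second is the prompt. The names are hypothetical; if you would rather not edit the installed library, the same return statement could presumably live in a subclass of WhisperTokenizer that overrides build_inputs_with_special_tokens (an untested assumption; WhisperTokenizerFast is a separate class and would need its own handling).

```python
from typing import List

# Illustrative token strings standing in for the real special-token IDs.
START_OF_PREV = "<|startofprev|>"
PREFIX = ["<|startoftranscript|>", "<|en|>", "<|transcribe|>", "<|notimestamps|>"]
EOS = "<|endoftext|>"


def prompted_pair(tokens_0: List[str], tokens_1: List[str]) -> List[str]:
    """Proposed pair layout: tokens_0 is the transcription, tokens_1 is the prompt."""
    return [START_OF_PREV] + tokens_1 + PREFIX + tokens_0 + [EOS]


transcription = ["and", "the", "actual", "transcription"]
prompt = ["This", "is", "the", "prompt"]
sequence = prompted_pair(transcription, prompt)
print(sequence)
# The prompt now sits between <|startofprev|> and <|startoftranscript|>.
```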

After that change, passing both the actual transcription and the prompt to the tokenizer should return the expected format:

from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small", language="english", task="transcribe")
processor.tokenizer(" and the actual transcription is this.", "This is the prompt")

# Decoding the returned input_ids should give:
# <|startofprev|>This is the prompt<|startoftranscript|><|en|><|transcribe|><|notimestamps|> and the actual transcription is this.<|endoftext|>

I hope this serves as the starting point to finetune Whisper with prompts.

Hello! Thanks for your idea.

Now you have the correct IDs. However, how do you pass them to the Trainer? I assumed this was handled automatically, but the deeper I go into the code, the more confused I get. With this format, I think the sequence needs to be split so that the model (in your example):

  • generates with prompt_ids, i.e. everything up to, but not including, <|startoftranscript|>;
  • gets decoder_input_ids up to and including <|notimestamps|>.
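The split described in the two bullet points can be sketched on token strings as follows. It assumes the sequence contains one <|startoftranscript|> followed later by <|notimestamps|>, as in the format above; the function name is hypothetical and real code would work on IDs.

```python
from typing import List, Tuple


def split_prompted_sequence(tokens: List[str]) -> Tuple[List[str], List[str], List[str]]:
    """Split one prompted sequence into (prompt part, decoder prefix, transcription part)."""
    sot = tokens.index("<|startoftranscript|>")
    no_ts = tokens.index("<|notimestamps|>", sot)  # search only after <|startoftranscript|>
    prompt_part = tokens[:sot]               # <|startofprev|> + prompt, without <|startoftranscript|>
    decoder_prefix = tokens[sot:no_ts + 1]   # <|startoftranscript|> ... <|notimestamps|>
    transcription_part = tokens[no_ts + 1:]  # transcription text + <|endoftext|>
    return prompt_part, decoder_prefix, transcription_part


sequence = ["<|startofprev|>", "This", "is", "the", "prompt",
            "<|startoftranscript|>", "<|en|>", "<|transcribe|>", "<|notimestamps|>",
            "and", "the", "actual", "transcription", "<|endoftext|>"]
prompt_part, decoder_prefix, transcription_part = split_prompted_sequence(sequence)
print(prompt_part)
print(decoder_prefix)
print(transcription_part)
```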

When I looked into how decoder_input_ids are generated from the labels during training (I know this happens automatically), I found this in the model's forward pass:

# In WhisperForConditionalGeneration.forward:
if labels is not None:
    if decoder_input_ids is None and decoder_inputs_embeds is None:
        decoder_input_ids = shift_tokens_right(
            labels, self.config.pad_token_id, self.config.decoder_start_token_id
        )

def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    # Move every label one position to the right and put the start token in front
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids

(Source: transformers/src/transformers/models/whisper/modeling_whisper.py on the main branch of huggingface/transformers on GitHub.)
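Part of the answer is teacher forcing: at each position the decoder is fed the previous ground-truth token and trained to predict the label at that position, so the decoder inputs are just the labels moved one step to the right with a start token in front. Here is a list-based mirror of the torch function above (no torch needed) with hypothetical IDs, to show what the shift does to a padded label row.

```python
from typing import List, Optional


def shift_tokens_right(rows: List[List[int]], pad_token_id: Optional[int],
                       decoder_start_token_id: int) -> List[List[int]]:
    """List-based mirror of the torch shift_tokens_right, for a batch of label rows."""
    if pad_token_id is None:
        raise ValueError("pad_token_id has to be defined.")
    shifted = []
    for row in rows:
        # Prepend the start token and drop the last label, so position t holds label t - 1.
        new_row = [decoder_start_token_id] + row[:-1]
        # Replace -100 (ignored label positions) with the pad token, as masked_fill_ does.
        shifted.append([pad_token_id if t == -100 else t for t in new_row])
    return shifted


# Hypothetical IDs: 1 = start token, 2 = end-of-text, 0 = pad, 11-13 = text tokens.
labels = [[11, 12, 13, 2, -100, -100]]
print(shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=1))
# [[1, 11, 12, 13, 2, 0]]
```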

However, this just made me more confused: why does shifting the labels to the right (and replacing -100 with the pad token) produce the correct decoder_input_ids? I would expect them to be the following (as IDs, of course), not just the labels shifted right with <|startoftranscript|> prepended:

<|startofprev|>This is the prompt<|startoftranscript|> <|transcribe|> <|en|> <|notimestamps|>
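Not a confirmed answer, but one possible approach based on the forward pass quoted above: shift_tokens_right is only called when decoder_input_ids is None, so you could build decoder_input_ids yourself from the full prompted sequence and pass them together with labels. Simply shifting labels that already contain the prompt would put decoder_start_token_id (<|startoftranscript|>) in front of <|startofprev|>, which is not the paper's format. The sketch below masks the prompt targets with -100 so the loss ignores them; the function name, the masking choice and the IDs are all hypothetical, and batching, padding and the data collator are left out.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # label value skipped by the cross-entropy loss


def build_prompted_example(full_ids: List[int], sot_id: int) -> Tuple[List[int], List[int]]:
    """Return (decoder_input_ids, labels) for one sequence that starts with a prompt."""
    decoder_input_ids = full_ids[:-1]  # inputs: every token except the last
    labels = list(full_ids[1:])        # targets: the next token at every position
    sot_pos = full_ids.index(sot_id)
    # labels[i] is full_ids[i + 1], so the first sot_pos labels are the prompt tokens
    # and <|startoftranscript|> itself; mask them so only the transcription part is scored.
    labels[:sot_pos] = [IGNORE_INDEX] * sot_pos
    return decoder_input_ids, labels


# Hypothetical IDs: 10 = <|startofprev|>, 21-22 = prompt, 1 = <|startoftranscript|>,
# 3 = <|en|>, 4 = <|transcribe|>, 5 = <|notimestamps|>, 31-32 = text, 2 = <|endoftext|>.
full = [10, 21, 22, 1, 3, 4, 5, 31, 32, 2]
dec_in, labels = build_prompted_example(full, sot_id=1)
print(dec_in)  # [10, 21, 22, 1, 3, 4, 5, 31, 32]
print(labels)  # [-100, -100, -100, 3, 4, 5, 31, 32, 2]
```

Both lists have the same length, so position i of the logits is compared with labels[i], and the first scored target is the language token, as in the non-prompted case.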

Any input to this? Thank you very much!