Adding custom vocabularies on Whisper

Hi,

I’ve been running some ASR tests with Whisper and it performs quite well, especially in English (which is my main use case). However, it sometimes fails to recognize uncommon terms such as entities or acronyms. For instance, when a speaker says:

...I hold access to SDRs...

The transcription looks like:

...I hold access to as the ours...

The question is: how can I tune the Whisper components (i.e. tokenizer, processor, model and so on) so that these domain-specific terms or acronyms are recognized better? In other words, is there a way to tell Whisper that “SDR” is a perfectly likely token? If so, what would be the best approach?

Thanks for your help!

Hey @sebasarango1180,

If you have a corpus of paired audio-text data with examples of such terms/entities/acronyms, you could experiment with fine-tuning the Whisper model on this dataset and see whether this improves downstream ASR performance on this distribution of data. To do so, you can follow the blog post at Fine-Tune Whisper For Multilingual ASR with 🤗 Transformers.

Since we’re fine-tuning for English-only (rather than multilingual), we need to make two modifications to the recipe outlined in the blog post:

  1. Use an English-only checkpoint (e.g. small.en instead of small)
  2. Omit the language and task args when we instantiate the processor:
processor = WhisperProcessor.from_pretrained("openai/whisper-small.en")  # previously we set language="hi" and task="transcribe" -> we omit these args for English ASR

I would provisionally try fine-tuning just the model weights, leaving the feature extractor and tokenizer as they come from the pre-trained OpenAI checkpoint. There shouldn’t be a need to change the feature extractor in any circumstance - this component simply converts the raw audio waveform to a log-Mel spectrogram (see section Load WhisperFeatureExtractor).

The tokenizer has an extensive byte-pair vocabulary of ~50k sub-word tokens that can be used to form any word in the English language. I would first try leveraging the pre-trained tokenizer from OpenAI without any modifications. Note that this tokenizer won’t necessarily have your specific terms/entities/acronyms in its vocabulary, but it will be able to form them from sub-word tokens (e.g. individual characters). Whilst this might be sub-optimal for predicting the expected acronyms (SDR, for example, may be split into the three sub-word tokens S, D and R), it does mean that we can leverage all of the pre-trained weights from the OpenAI model directly. As soon as we change the tokenizer, for example by adding extra vocabulary items, we change the dimensionality of the final classification layer, and thus have to randomly initialise some proportion of its weights.

Another reason I believe using sub-word tokens to predict acronyms should work is that sub-word tokens more closely reflect the phonetic sounds of the audio. For example, when we say “SDR”, we don’t pronounce it as a single word, but rather say each of the letters individually (“ESS DEE ARR”). Thus, the model should be able to predict the individual tokens for each letter (S D R) when conditioned on the acoustic information.

If this fails, we can experiment with adding the vocabulary items to the tokenizer and resizing the embedding layer. Note that this approach will only work if we have a corpus of data to train on: since we randomly initialise the new embedding weights, we’ll need to train the model in order for it to generate sensible predictions.

from transformers import WhisperTokenizer, WhisperForConditionalGeneration

# load pre-trained tokenizer and model
ckpt = "openai/whisper-small.en"
tokenizer = WhisperTokenizer.from_pretrained(ckpt)
model = WhisperForConditionalGeneration.from_pretrained(ckpt)

# define the domain-specific terms to add (replace ... with your own terms)
new_tokens = ["SDR", ...]

# keep only the tokens that are not already in the vocabulary
new_tokens = set(new_tokens) - set(tokenizer.get_vocab().keys())

# add the tokens to the tokenizer vocabulary
tokenizer.add_tokens(list(new_tokens))

# add new random embeddings for the appended tokens
model.resize_token_embeddings(len(tokenizer))

Supposing you don’t have paired audio-text data for fine-tuning, we could explore using an initial_prompt to boost the log-probs for certain vocab items, as is done in the ‘official’ OpenAI implementation. See prompt vs prefix in DecodingOptions · Discussion #117 · openai/whisper · GitHub and whisper/transcribe.py at 0f39c89d9212e4d0c64b915cf7ba3c1f0b59fecc · openai/whisper · GitHub for info.
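Since this reply works with 🤗 Transformers, note that recent versions of `transformers` also accept a text prompt for Whisper via `get_prompt_ids` and the `prompt_ids` argument of `generate`, which plays a similar role to `initial_prompt`. A minimal sketch, assuming a recent `transformers` release and 16 kHz mono audio as a NumPy array; `transcribe_with_prompt` is a hypothetical helper name and the prompt text is only an example.

```python
import numpy as np
from transformers import WhisperProcessor

CKPT = "openai/whisper-small.en"
processor = WhisperProcessor.from_pretrained(CKPT)

# Prompt text that biases the decoder towards the domain vocabulary
prompt = "Let's talk about International Monetary Fund and SDRs."
# 1-D tensor starting with the <|startofprev|> token, followed by the prompt tokens
prompt_ids = processor.get_prompt_ids(prompt, return_tensors="pt")


def transcribe_with_prompt(model, audio: np.ndarray, sampling_rate: int = 16000) -> str:
    """Hypothetical helper: transcribe one clip, conditioning generation on `prompt_ids`."""
    features = processor.feature_extractor(
        audio, sampling_rate=sampling_rate, return_tensors="pt"
    ).input_features
    generated = model.generate(features, prompt_ids=prompt_ids)
    return processor.batch_decode(generated, skip_special_tokens=True)[0]
```

Usage: load `WhisperForConditionalGeneration.from_pretrained(CKPT)` and call `transcribe_with_prompt(model, audio)` with your 16 kHz waveform.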


Instead of adding new tokens to the vocabulary and resizing the embedding layer with random initialization, would it make sense to replace the N least-used tokens with the new tokens?

@sanchit-gandhi Hello Sanchit. Can you clarify one more thing? How do we determine whether there are new tokens in the dataset that need to be added to the tokenizer? As far as I understand, if the dataset contains new tokens and we fine-tune the model without adding them, there may be problems when calculating the metric: it will be computed with the unknown tokens taken into account. Consequently, the model will not be able to fully learn the knowledge in the dataset.

Here are two other approaches.
Neither requires training, so I highly recommend trying them before fine-tuning models or changing their architecture.

1. Initial Prompt

You can simply use the parameter initial_prompt to create a bias towards your vocabulary.
In your example, you could write: "Let's talk about International Monetary Fund and SDRs."
This will encourage the model to repeat the term SDRs and other terms related to finances.


or…

2. Suppress Tokens

Sometimes Whisper keeps producing the same wrong word. If that’s the case, you can suppress that token.

For example, let’s pretend there’s a Latin name “Esthear” and Whisper transcribes the audio as “I hold access to Esthear’s…”
Pretend this name is represented by the tokens “Esthe” and “ar”, with IDs 98765 and 12345.

If you suppress the token “Esthe”, Whisper will need to come up with alternatives to transcribe your audio, and will hopefully guess “SDR” correctly.

But be careful not to suppress common tokens. If you suppress both “Esthe” and “ar”, it might affect other words that contain “ar”, like “mon-et-ar-y” and “tow-ar-ds”.


Code Example

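import whisper
model = whisper.load_model("small.en")  # English-only checkpoint
initial_prompt = "Let's talk about International Monetary Fund and SDRs."
result = model.transcribe("audio.mp3", initial_prompt=initial_prompt, suppress_tokens=[-1, 98765])  # -1 keeps the default symbol suppression; "audio.mp3" and 98765 are placeholders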