Language detection with Whisper

The original Whisper model supports dynamically detecting the language of the input audio, either by default as part of its model.transcribe() method or by doing something like this:

mel = whisper.log_mel_spectrogram(audio).to(model.device)
_, probs = model.detect_language(mel)
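
For context, the full flow with the openai-whisper package looks roughly like this (the model name and audio path below are just placeholders):

import whisper

model = whisper.load_model("base")
audio = whisper.pad_or_trim(whisper.load_audio("audio.wav"))  # detect_language looks at a 30-second window
mel = whisper.log_mel_spectrogram(audio).to(model.device)
_, probs = model.detect_language(mel)
print(max(probs, key=probs.get))  # most likely language code, e.g. 'en'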

It looks like the Transformers implementation supports setting the language on the WhisperTokenizer/WhisperProcessor, but I’m wondering if there’s an equivalent language detection method.

I looked around the modeling code and didn't see anything related to language detection; it looks like the language is set by just force-decoding a language token at the beginning of the output. Looking at the OpenAI implementation, its language detection method just takes a single decoding step and returns the probabilities of all the language tokens.
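
By "force-decoding a language token" I mean something roughly like this (just a sketch, assuming the processor, model, and input_features are already set up):

# force Spanish transcription by prepending the language/task tokens to the decoder input
forced_decoder_ids = processor.get_decoder_prompt_ids(language="spanish", task="transcribe")
predicted_ids = model.generate(input_features, forced_decoder_ids=forced_decoder_ids)
print(processor.batch_decode(predicted_ids, skip_special_tokens=True))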

This makes me think that this wouldn't be hard to implement myself by taking one decoding step with the Transformers model and then pulling out the probabilities for the language tokens, but I'm wondering if there's a better way to do this or a plan to implement an official language detection method.

Edit: Okay, I sort of answered my own question while doing the research to ask it. I recognize that decoding normally, without forcing any language-related tokens, basically means the model is doing language detection as it decodes. But I'm still wondering whether there's a specific language detection method I missed, or whether I should just use the logits from the first decoding step.

In case anyone else comes looking to try and do the same thing, here’s what I implemented to do language detection:

from typing import Collection, Dict, List, Optional

import torch
from transformers import WhisperForConditionalGeneration, WhisperTokenizer


def detect_language(model: WhisperForConditionalGeneration, tokenizer: WhisperTokenizer, input_features,
                    possible_languages: Optional[Collection[str]] = None) -> List[Dict[str, float]]:
    # hacky, but all language tokens and only language tokens are 6 characters long
    language_tokens = [t for t in tokenizer.additional_special_tokens if len(t) == 6]
    if possible_languages is not None:
        language_tokens = [t for t in language_tokens if t[2:-2] in possible_languages]
        if len(language_tokens) < len(possible_languages):
            raise RuntimeError(f'Some languages in {possible_languages} did not have associated language tokens')

    language_token_ids = tokenizer.convert_tokens_to_ids(language_tokens)

    # 50258 is <|startoftranscript|>; the model predicts the language token immediately after it
    logits = model(input_features,
                   decoder_input_ids=torch.tensor([[50258] for _ in range(input_features.shape[0])])).logits
    # mask out everything that is not a language token so the softmax is only over languages
    mask = torch.ones(logits.shape[-1], dtype=torch.bool)
    mask[language_token_ids] = False
    logits[:, :, mask] = -float('inf')

    output_probs = logits.softmax(dim=-1).cpu()
    return [
        {
            lang: output_probs[input_idx, 0, token_id].item()
            for token_id, lang in zip(language_token_ids, language_tokens)
        }
        for input_idx in range(logits.shape[0])
    ]

Thank you very much, it seems really interesting. Can you share an example?
I tried to execute your function, but an error arises.
Do you think it is possible to create a function that translates to Spanish?
Thank you very much!

processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")
input_features = {"Hola", "como", "estás"}
detect_language(model, tokenizer, input_features)

Whisper is a model for processing audio, so your input features need to be audio that has been processed by the WhisperProcessor. For me this looks something like this:

    import torchaudio

    # Whisper expects 16 kHz audio, so resample first if your file uses a different rate
    waveform, sample_rate = torchaudio.load(str(audio_path))
    input_features = processor(waveform.squeeze().numpy(), sampling_rate=sample_rate,
                               return_tensors="pt").input_features

    language = detect_language(model, tokenizer, input_features, {'en', 'zh'})

For translating to Spanish, you'll want to use a translation model like MarianMT.
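
A rough sketch of that (Helsinki-NLP/opus-mt-en-es is just one English-to-Spanish checkpoint you could pick):

from transformers import MarianMTModel, MarianTokenizer

mt_tokenizer = MarianTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-es")
mt_model = MarianMTModel.from_pretrained("Helsinki-NLP/opus-mt-en-es")

# translate the Whisper transcription (here a placeholder string) into Spanish
batch = mt_tokenizer(["How are you doing today?"], return_tensors="pt")
translated_ids = mt_model.generate(**batch)
print(mt_tokenizer.batch_decode(translated_ids, skip_special_tokens=True))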

Clear! Thank you very much! I thought Whisper could work with text input too. I assumed that somewhere in the logic of the algorithm there would first be a phase of converting the audio to text, and then a second phase doing the detection or translation on the text.

How can we use this in a scenario where the speech is multilingual, so the language keeps changing throughout the whole speech?

That's a tricky question. You could try something like splitting the speech up into segments and generating language tokens for each segment (there's a rough sketch of that at the end of this reply), but unfortunately I don't know of a silver-bullet solution for this.

Whisper does a decent job recognizing multilingual speech out of the box (and even better with fine-tuning) though, so you could also just try and recognize all the speech and then segment the text based on language. That might be easier.
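
If you do go the segmenting route, here's roughly what I mean, reusing the detect_language function above (the segment length and helper name are just placeholder choices of mine):

def detect_language_per_segment(model, tokenizer, processor, waveform, sample_rate,
                                segment_s=5.0, possible_languages=None):
    # split the waveform into fixed-length segments and detect the language of each one
    samples_per_segment = int(segment_s * sample_rate)
    results = []
    for start in range(0, waveform.shape[-1], samples_per_segment):
        segment = waveform[..., start:start + samples_per_segment]
        input_features = processor(segment.squeeze().numpy(), sampling_rate=sample_rate,
                                   return_tensors="pt").input_features
        probs = detect_language(model, tokenizer, input_features, possible_languages)[0]
        results.append((start / sample_rate, max(probs, key=probs.get)))
    # returns (segment start time in seconds, most likely language) for each segment
    return results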

Is there a way that I can limit Whisper's language identification choices? My audio files contain three languages. Currently, I am chunking my audio files into 3-second pieces, feeding them to Whisper, and getting the language ID. However, it sometimes detects a language that isn't in the file at all!

So, I was thinking of limiting Whisper's choices, or somehow using Whisper's features to do some post-processing to get a more accurate result.

What I want to do is determine the exact time of the language switch. If you know another method with a smaller resolution (smaller than 3 seconds), let me know.

Is there a way to handle longer audio (5 minutes), or do I need to manually split it, run inference on each split, and then average the results across the splits?

Use the pipeline. The pipeline automatically chunks audio longer than 30 seconds (or any lower value you set), including a definable overlap length.

from transformers import pipeline

transpipe = pipeline(
    "automatic-speech-recognition",
    model="openai/whisper-medium",
    chunk_length_s=30,    # 30 seconds is the maximum chunk length
    device="cuda",
    stride_length_s=5,    # the overlapping audio length
)

transcription = transpipe("my60minuteFile.wav")

No, I need to identify the language, not run transcription. I have 5-minute-long audio, so I'd need 10 inferences over it (one per 30-second chunk).

Sorry, I was just going by your reply, and that wasn't clear from it. Sorry again.
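
For the identification-only case, one rough way to do it (just a sketch, reusing the detect_language_per_segment helper sketched earlier in the thread with 30-second segments) is to take a majority vote over the per-segment results; averaging the per-segment probabilities from detect_language would work similarly:

from collections import Counter

segments = detect_language_per_segment(model, tokenizer, processor, waveform, sample_rate,
                                        segment_s=30)
# segments is a list of (start time, language); take the most common language overall
print(Counter(lang for _, lang in segments).most_common(1)[0][0])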