Language detection with Whisper

The original Whisper model supports detecting the language of the input audio dynamically, either by default as part of its model.transcribe() method or by doing something like this:

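```python
import whisper  # openai-whisper; `model` and `audio` loaded beforehand

audio = whisper.pad_or_trim(audio)  # detect_language works on a 30 s window
mel = whisper.log_mel_spectrogram(audio).to(model.device)
_, probs = model.detect_language(mel)
```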

It looks like the Transformers implementation supports setting the language on the WhisperTokenizer/WhisperProcessor, but I’m wondering if there’s an equivalent language detection method.

I looked around the modeling code and didn’t see anything related to language detection. It looks like the language is set simply by force-decoding a language token at the beginning of the output, and in the OpenAI implementation the language detection method just takes a single decoding step and returns the probabilities for all the language tokens.

This makes me think that it wouldn’t be hard to implement this myself by taking one decoding step with the Transformers model and then pulling out the probabilities for the language tokens, but I’m wondering if there’s a better way to do this or a plan to implement an official language detection method.
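In both cases the language probabilities come down to a softmax taken only over the language-token logits of that first decoding step, with every other logit treated as -inf. Here is a minimal, framework-free sketch of just that computation, assuming you already have the first-step logit vector as a plain list. The function name `language_probs_from_logits` is hypothetical, and the token ids in the example are made up, not Whisper's real vocabulary ids:

```python
import math
from typing import Dict, Sequence


def language_probs_from_logits(logits: Sequence[float],
                               language_token_ids: Sequence[int]) -> Dict[int, float]:
    """Softmax restricted to the language tokens (hypothetical helper).

    `logits` is the vocabulary-sized logit vector from the first decoding
    step. Ignoring every non-language token is equivalent to setting its
    logit to -inf before a full softmax.
    """
    selected = [logits[i] for i in language_token_ids]
    m = max(selected)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in selected]
    total = sum(exps)
    return {tok: e / total for tok, e in zip(language_token_ids, exps)}
```

For example, `language_probs_from_logits([0.0, math.log(3), 0.0, 0.0], [1, 2])` gives approximately `{1: 0.75, 2: 0.25}`. Tokens outside `language_token_ids` never affect the result, however large their logits are.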

Edit: Okay, I kind of answered my own question while doing the research to ask it. I recognize that decoding normally, without forcing any language-related tokens, basically means the model performs language detection as it decodes. I’m still wondering, though, whether there’s a specific language detection method I missed or whether I should just use the logits from the first decoding step.
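If you go the unforced-decoding route, the detected language can be read back from the generated sequence, because multilingual Whisper emits the language token directly after `<|startoftranscript|>`. A sketch, assuming the output is already available as token strings (for example from `tokenizer.convert_ids_to_tokens(...)` on the ids returned by `model.generate(...)`, if your version includes the prompt tokens in that output). `language_from_decoded_tokens` and its `language_codes` parameter are hypothetical names:

```python
from typing import Collection, Optional, Sequence


def language_from_decoded_tokens(tokens: Sequence[str],
                                 language_codes: Collection[str]) -> Optional[str]:
    """Return the language code chosen during unforced decoding, or None.

    Assumes the multilingual Whisper prompt layout
    <|startoftranscript|> <|xx|> <|task|> ..., i.e. the language token is
    the one generated right after the start-of-transcript token.
    """
    if "<|startoftranscript|>" not in tokens:
        return None
    pos = list(tokens).index("<|startoftranscript|>") + 1
    if pos >= len(tokens):
        return None
    token = tokens[pos]
    # Strip the "<|" and "|>" wrapper; anything else is not a language token
    code = token[2:-2] if token.startswith("<|") and token.endswith("|>") else None
    return code if code in language_codes else None
```

For instance, `["<|startoftranscript|>", "<|de|>", "<|transcribe|>", " Hallo"]` with `{"en", "de"}` yields `"de"`. Checking against `language_codes` keeps task tokens such as `<|transcribe|>` from being mistaken for a language.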

In case anyone else comes looking to try and do the same thing, here’s what I implemented to do language detection:

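```python
from typing import Collection, Dict, List, Optional

import torch
from transformers import WhisperForConditionalGeneration, WhisperTokenizer
from transformers.models.whisper.tokenization_whisper import LANGUAGES  # {"en": "english", ...}


@torch.no_grad()
def detect_language(model: WhisperForConditionalGeneration, tokenizer: WhisperTokenizer, input_features,
                    possible_languages: Optional[Collection[str]] = None) -> List[Dict[str, float]]:
    # Language tokens look like "<|en|>". Matching on length (6 chars) would miss "<|haw|>",
    # so match the code between "<|" and "|>" against Whisper's language table instead.
    language_tokens = [t for t in tokenizer.additional_special_tokens if t[2:-2] in LANGUAGES]
    if possible_languages is not None:
        language_tokens = [t for t in language_tokens if t[2:-2] in possible_languages]
        if len(language_tokens) < len(possible_languages):
            raise RuntimeError(f'Some languages in {possible_languages} did not have associated language tokens')

    language_token_ids = tokenizer.convert_tokens_to_ids(language_tokens)

    # Decode one step from <|startoftranscript|> (id 50258 in the multilingual vocab);
    # the next-token prediction at that position is the language token.
    sot_id = tokenizer.convert_tokens_to_ids('<|startoftranscript|>')
    decoder_input_ids = torch.full((input_features.shape[0], 1), sot_id,
                                   dtype=torch.long, device=input_features.device)
    logits = model(input_features, decoder_input_ids=decoder_input_ids).logits

    # Keep only the language-token logits; everything else gets -inf before the softmax
    mask = torch.ones(logits.shape[-1], dtype=torch.bool, device=logits.device)
    mask[language_token_ids] = False
    logits[:, :, mask] = -float('inf')

    output_probs = logits.softmax(dim=-1).cpu()
    # One dict per input, mapping each language token (e.g. "<|en|>") to its probability
    return [
        {
            lang: output_probs[input_idx, 0, token_id].item()
            for token_id, lang in zip(language_token_ids, language_tokens)
        }
        for input_idx in range(logits.shape[0])
    ]
```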
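detect_language returns one dict per input, keyed by the full language token. To reduce that to a single language code per input, a small helper like the following is enough; `top_languages` is a hypothetical name, not part of Transformers:

```python
from typing import Dict, List


def top_languages(results: List[Dict[str, float]]) -> List[str]:
    """Pick the most probable language per input and strip the <|...|> wrapper.

    `results` is assumed to be the list returned by detect_language: one
    dict per input, mapping tokens such as "<|en|>" to probabilities.
    """
    return [max(probs, key=probs.get)[2:-2] for probs in results]
```

For example, `top_languages([{"<|en|>": 0.9, "<|de|>": 0.1}, {"<|en|>": 0.2, "<|fr|>": 0.8}])` returns `["en", "fr"]`.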