Disparity between output from `forward` and `generate` for greedy search (using Whisper)

To whom it may concern,

First, here is some context. As part of a research project, I want to distill the Whisper model on a dataset of non-native English speakers. By distilling Whisper, the goal is to create a more streamlined and cost-efficient model for speech recognition while maintaining high performance. Because I am using a custom loss, because generate doesn’t propagate gradients, and because I need more control over teacher forcing, I decided to use Whisper’s forward method.

After taking a look at the forward documentation for WhisperDecoderLayer, it seems that setting model.config.forced_decoder_ids for the task, language, and BOS tokens doesn’t affect the forward method. For now, I decided to add these tokens manually using the code in Snippet 2.
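
For reference, here is roughly what get_decoder_prompt_ids returns on the multilingual tokenizer (the exact IDs below are illustrative and may differ between checkpoints):

from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")
print(tokenizer.get_decoder_prompt_ids(language="en", task="transcribe", no_timestamps=True))
# e.g. [(1, 50259), (2, 50359), (3, 50363)]  ->  <|en|>, <|transcribe|>, <|notimestamps|>
# Note that <|startoftranscript|> (50258) is not included, which is why
# Snippet 2 prepends it manually.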

To test whether I could match the behavior of generate with forward, I ran forward in a teacher-forced way on the output of generate, which uses greedy search here. If the two outputs don’t match, then my forward-based reimplementation of greedy search doesn’t reproduce what generate does.

According to my results in Snippet 3, the outputs don’t match, so something in my reasoning must be wrong.

My questions are thus:

  1. Is my understanding of teacher-forced predictions correct with respect to Snippet 1?
  2. If yes, how should I change my code to match the output from generate?

Thank you very much in advance for your time.

Yours sincerely,

Tony

Snippet 1:

model.forward(input_features=input_features,
              decoder_input_ids=labels_with_prompt)

Snippet 2:

from typing import Tuple
import torch
from transformers import WhisperTokenizer


BOS_TOKEN_ID = 50258


def get_labels_with_prompt(labels: torch.Tensor,
                           tokenizer: WhisperTokenizer,
                           language: str = "en",
                           task: str = "transcribe",
                           no_timestamps: bool = True) -> Tuple[torch.Tensor, int, int]:
    """
    Returns the labels with the prefix and suffix tokens, as well as the number of prefix and suffix tokens.
    `labels_with_prompt` should be used as the `decoder_input_ids` argument for the `forward` method of the model.
    
    Note: n_prefix_tokens should be 4 (BOS, language, task, if_timestamps) and n_suffix_tokens should be 1 (EOS).
    """
    
    # Get batch size:
    batch_size = labels.shape[0]

    # Get prefix tokens:
    forced_decoder_ids = tokenizer.get_decoder_prompt_ids(language=language, task=task, no_timestamps=no_timestamps)  # language, task, if_timestamps
    prefix_tokens = torch.IntTensor([BOS_TOKEN_ID] + [token_id for idx, token_id in forced_decoder_ids])  # (n_prefix_tokens, )
    prefix_tokens = prefix_tokens.expand(batch_size, -1)  # (batch_size, n_prefix_tokens)

    # Get suffix tokens:
    suffix_tokens = torch.IntTensor([tokenizer.eos_token_id])  # (n_suffix_tokens, )
    suffix_tokens = suffix_tokens.expand(batch_size, -1)  # (batch_size, n_suffix_tokens)

    # Get prefix and suffix lengths:
    n_prefix_tokens = prefix_tokens.shape[1]  # n_prefix_tokens
    n_suffix_tokens = suffix_tokens.shape[1]  # n_suffix_tokens
    
    # Send tensors to the same device as the `labels` tensor:
    prefix_tokens = prefix_tokens.to(labels.device)
    suffix_tokens = suffix_tokens.to(labels.device)
    
    # Concatenate the prefix tensor with the original tensor along the second dimension:
    labels_with_prompt = torch.cat((prefix_tokens, labels, suffix_tokens), dim=1)  # (batch_size, n_tokens_labels + n_prefix_tokens + n_suffix_tokens)

    return labels_with_prompt, n_prefix_tokens, n_suffix_tokens
 

Snippet 3:

#!/usr/bin/env python
# coding: utf-8

import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
from trainer.prompting import get_labels_with_prompt


model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")

model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")  # type: ignore
model.config.suppress_tokens = []

normalizer = processor.tokenizer._normalize


## Load dataset
# load dummy dataset and read audio files
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
x = next(iter(ds))

label = normalizer(x["text"])  # normalize label
input_features = processor(x["audio"]["array"], sampling_rate=x["audio"]["sampling_rate"], return_tensors="pt").input_features

## Tokenize the labels for teacher-forcing
tokenized_label = torch.LongTensor(processor.tokenizer(label, add_special_tokens=False).input_ids)

# Add batch dim:
tokenized_labels = tokenized_label[None, :]


## Add prompts to teacher-forced labels
processor.tokenizer.get_decoder_prompt_ids(language=None, task=None)

labels_with_prompt, n_prefix_tokens_labels, n_suffix_tokens_labels = get_labels_with_prompt(
    labels=tokenized_labels, language="english", task="transcribe", tokenizer=processor.tokenizer)

processor.tokenizer.batch_decode(labels_with_prompt, skip_special_tokens=False, normalize=False)


## Predict
## Teacher-forced from greedy search

# Generate with greedy search - vanilla
pred_gen_raw = model.generate(inputs=input_features)
pred_gen_str = processor.tokenizer.batch_decode(pred_gen_raw, skip_special_tokens=True, normalize=True)
pred_gen = torch.LongTensor(processor.tokenizer.encode(pred_gen_str[0], add_special_tokens=False))[None, :]

processor.tokenizer.batch_decode(pred_gen, skip_special_tokens=False, normalize=False)

pred_gen_with_prompts, n_prefix_tokens_labels, n_suffix_tokens_labels = get_labels_with_prompt(
    labels=pred_gen, language=None, task=None, tokenizer=processor.tokenizer)

output = model.forward(input_features=input_features,
                       decoder_input_ids=pred_gen_with_prompts)
logits = output.logits
pred_ids = torch.argmax(logits, dim=-1)

processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=False, normalize=False)


## With `generate`

# Generate with greedy search - vanilla
pred_gen = model.generate(inputs=input_features)
processor.tokenizer.batch_decode(pred_gen, skip_special_tokens=False, normalize=False)


## Comparison
print(processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True, normalize=False))
print(processor.tokenizer.batch_decode(pred_gen, skip_special_tokens=True, normalize=False))

The last two print statements output the following:

>> [' Mrister Quter is theive apostle of thethe middle classes and we are glad toto welcome his gospel.']
>> [' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.']

System Info

  • transformers version: 4.30.2
  • Platform: macOS-13.4-arm64-arm-64bit
  • Python version: 3.10.11
  • Huggingface_hub version: 0.15.1
  • Safetensors version: 0.3.1
  • PyTorch version (GPU?): 2.0.0 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Hey,

I know this is a rather old post, but I bumped into it while trying to understand how to perform just one Whisper decoder step (and, more generally, how its forward method works), which took me longer than I would’ve liked. So I’ve decided to post my understanding here. I am using WhisperForConditionalGeneration.

The theory

Whisper is an encoder-decoder model with attention. The encoder creates an internal representation of the different parts of the input (here, audio). The decoder then attends over these representations and predicts the output sequentially, taking its own past outputs as input at every step.

As I understand it, the forward method of the model encodes the audio and performs one decoder step based on the decoder_input_ids you provide. To generate the whole sequence, one needs to “classify” its output and feed the model/decoder again with the updated (now one token longer) decoder_input_ids.

How you “classify” the output is where the decoding strategy comes into play: with greedy search, you simply take the token with the highest probability and append it to the decoder input for the next step. With beam search, for example, you would instead keep the n highest-scoring beams, feed the decoder with all of them, and again keep the n best beams in the next step.

The generate method takes care of all that for you, iteratively feeding the decoder according to the decoding strategy until the end-of-sequence token (or the maximum generation length) is reached.
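
To make the greedy case concrete, here is a rough sketch of the loop that generate effectively runs (heavily simplified: no KV cache, no suppressed tokens, batch size 1, stopping only on end-of-sequence or a maximum length), assuming the model, processor, and input_features from your Snippet 3:

import torch

# Manual greedy decoding (sketch): start from the forced prefix tokens
# (<|startoftranscript|>, language, task, <|notimestamps|>) and repeatedly
# append the highest-probability next token.
prefix_ids = [processor.tokenizer.convert_tokens_to_ids("<|startoftranscript|>")]
prefix_ids += [tok_id for _, tok_id in processor.get_decoder_prompt_ids(language="en", task="transcribe")]
decoder_input_ids = torch.tensor(prefix_ids).unsqueeze(0)

eos_id = processor.tokenizer.eos_token_id
max_new_tokens = 100

for _ in range(max_new_tokens):
    out = model(input_features=input_features, decoder_input_ids=decoder_input_ids)
    next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)  # shape (1, 1)
    decoder_input_ids = torch.cat([decoder_input_ids, next_token], dim=1)
    if next_token.item() == eos_id:
        break

print(processor.tokenizer.batch_decode(decoder_input_ids, skip_special_tokens=True))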

My experiments

I use a Czech Whisper model:

model_path = "mikr/whisper-large-v3-czech-cv13"
processor = WhisperProcessor.from_pretrained(model_path)
processor.tokenizer.set_prefix_tokens(task="transcribe", language="cs")
model = WhisperForConditionalGeneration.from_pretrained(model_path)
model.eval() # set model for inference

My audio is a sample from my dataset of a person saying “Pavel” (a Czech name). Note that I compute res with model.model and do not yet apply model.proj_out, which projects the decoder output onto the output vocabulary.

audio = dataset[1]["audio"]
in_features = processor(audio["array"], sampling_rate=audio["sampling_rate"], return_tensors="pt").input_features
prompt_ids = torch.tensor(processor.tokenizer.prefix_tokens).unsqueeze(0)
res = model.model(input_features=in_features, decoder_input_ids=prompt_ids)

In res we get last_hidden_state (the hidden state of the last decoder layer), past_key_values (the cached attention keys and values), and encoder_last_hidden_state. Now let’s project the decoder output onto the vocabulary and decode it:

lm_logits = model.proj_out(res.last_hidden_state)
pred_ids = torch.argmax(lm_logits, dim=-1)
prediction = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=False, normalize=False)
print("Model output:", pred_ids)
print("Prediction decoded:", prediction[0])
print("Gold:", dataset[1]["normalized"])

Which gives the following results:

Model output: tensor([[50283, 50360, 50364, 18968]])
Prediction decoded: <|cs|><|transcribe|><|notimestamps|>Pa
Gold: pavel

As we can see, we got one more token in the prediction. Note that we could get the logits in one step using res = model(input_features=in_features, decoder_input_ids=prompt_ids) (which calls model.forward internally and is preferred over calling forward explicitly), receiving logits, past_key_values, and encoder_last_hidden_state, but I wanted to show the steps the model takes.
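
For completeness, that one-call variant looks like this (a sketch reusing the variables above; the top-level forward applies proj_out internally, so the logits come back already projected onto the vocabulary):

out = model(input_features=in_features, decoder_input_ids=prompt_ids)
one_step_pred_ids = out.logits.argmax(dim=-1)
print(torch.equal(one_step_pred_ids, pred_ids))  # should print True: same result as the two-step version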

Now let’s do the next decoder step, reusing encoder_last_hidden_state (we could run the whole forward again, but recomputing the encoder hidden states is unnecessary when we already have them) and appending the last prediction to the decoder_input_ids (note that if we used pred_ids alone, we would be missing the initial “start of transcript” token):

next_step = model.model.decoder(input_ids=torch.concat((prompt_ids, pred_ids[:, 3:]), dim=1), encoder_hidden_states=res.encoder_last_hidden_state)
next_logits = model.proj_out(next_step.last_hidden_state)
next_pred_ids = torch.argmax(next_logits, dim=-1)
next_prediction = processor.tokenizer.batch_decode(next_pred_ids, skip_special_tokens=False, normalize=False)
print("Next step output:", next_pred_ids)
print("Next step prediction decoded:", next_prediction[0])
print("Gold:", dataset[1]["normalized"])

Which outputs:

Next step output: tensor([[50283, 50360, 50364, 18968, 779]])
Next step prediction decoded: <|cs|><|transcribe|><|notimestamps|>Pavel
Gold: pavel

Again, one more token than last time.

Your code

I believe you now understand that because you pass decoder_input_ids=pred_gen_with_prompts to model.forward, you are telling the decoder to make one step based on the pred_gen_with_prompts you provide. Because pred_gen_with_prompts already contains the “end of sequence” token, I would think it does not really do anything and basically just copies pred_gen_with_prompts to the output, perhaps introducing the odd mistakes you see. You can try decoding pred_gen_with_prompts to see what it contains.
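
If you do want a genuinely new prediction out of forward, feed a prefix that does not yet end in the end-of-sequence token and read only the last position of the logits; the earlier positions only re-predict tokens that are already in decoder_input_ids, shifted by one. A rough sketch, reusing the names from your Snippet 3:

# Drop the trailing EOS so that the last position actually corresponds to a
# token the decoder has not seen yet, then take only that position's logits.
prefix = pred_gen_with_prompts[:, :-1]
out = model(input_features=input_features, decoder_input_ids=prefix)
next_token_id = out.logits[:, -1, :].argmax(dim=-1)  # (batch_size,)
print(processor.tokenizer.batch_decode(next_token_id[:, None]))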

Teacher forcing

Based on your code, I’m not sure you understand teacher forcing correctly. First and most importantly, in your Snippet 3 you are creating pred_gen_with_prompts from pred_gen_raw, which is a label generated by model.generate. This is not teacher forcing. Teacher forcing is when you provide the correct labels to the decoder during training, so it can learn to predict the next token based on the correct previous ones. To use teacher forcing, you would therefore want to build the teacher-forced input from your label = normalizer(x["text"]) instead.

Second (though that should already be clear from what I wrote above), you need to do this one decoder step after another. So if I wanted to use teacher forcing in my next_step example, I would do something like the following.

First, get the gold label tokens and see which prefix of them to feed in:

gold_tokenized = processor.tokenizer(dataset[1]["normalized"])["input_ids"]
print("Gold tokenized:", gold_tokenized)
print(processor.decode(gold_tokenized[:5]))

Output:

Gold tokenized: [50258, 50283, 50360, 50364, 4306, 779, 50257]
<|startoftranscript|><|cs|><|transcribe|><|notimestamps|>pa

Then, convert to tensor and feed as the decoder_input_ids for the next step:

teacher_ids = torch.tensor(gold_tokenized[:5]).unsqueeze(0)
next_step = model.model.decoder(input_ids=teacher_ids, encoder_hidden_states=res.encoder_last_hidden_state)
next_logits = model.proj_out(next_step.last_hidden_state)
next_pred_ids = torch.argmax(next_logits, dim=-1)
next_prediction = processor.tokenizer.batch_decode(next_pred_ids, skip_special_tokens=False, normalize=False)
print("Next step output:", next_pred_ids)
print("Next step prediction decoded:", next_prediction[0])
print("Gold:", dataset[1]["normalized"])

Again getting the correct output:

Next step output: tensor([[50283, 50360, 50364, 18968, 779]])
Next step prediction decoded: <|cs|><|transcribe|><|notimestamps|>Pavel
Gold: pavel
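
Putting this in a loop gives the usual teacher-forced scheme: at every step the decoder is conditioned on the gold prefix rather than on its own previous prediction. A rough sketch, assuming the variables from the example above (model, processor, res, gold_tokenized):

# Teacher-forced decoding, one step at a time: at step t the decoder sees
# the gold prefix gold_tokenized[:t] and we read off its prediction for
# position t from the last position of the logits.
teacher_forced_ids = []
for t in range(1, len(gold_tokenized)):
    prefix = torch.tensor(gold_tokenized[:t]).unsqueeze(0)
    step = model.model.decoder(input_ids=prefix,
                               encoder_hidden_states=res.encoder_last_hidden_state)
    step_logits = model.proj_out(step.last_hidden_state)
    teacher_forced_ids.append(step_logits[:, -1, :].argmax(dim=-1).item())

print(processor.tokenizer.decode(teacher_forced_ids, skip_special_tokens=False))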

I hope this helps at least one person. :slight_smile:

P.S.: Because we’re doing inference only, using with torch.no_grad(): would save memory and speed up computations, but I left it out for code clarity.
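
For example, the first decoder step above would simply become:

# Inference only: disable gradient tracking to save memory and time.
with torch.no_grad():
    res = model.model(input_features=in_features, decoder_input_ids=prompt_ids)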


Hey b0r3k,

Thanks so much for taking the time to write such a detailed and clear answer. It was very valuable to me :smile:

I’ll be honest: I need some time to get back into this before I can give you any relevant insight. But the most important takeaway I can highlight is from this part of your answer:

I would think it [the forward method] does not really do anything and basically just copies pred_gen_with_prompts to the output, perhaps introducing the odd mistakes you see.

from which I now understand that the only logits of interest are those related to the newly generated token.

FYI, I believe my post was quite confusing because I hadn’t properly explained where the labels came from (a teacher whisper-medium model) nor how exactly I wanted the predictions to be made for my custom distillation loss. For all these reasons, thank you again for going through my post anyway!

Have a very nice day :slight_smile:
Tony

EDIT: I was curious and found this answer from a member of the Hugging Face team (@fxmarty). You were right: you’re not supposed to use the non-final logits obtained with forward. Although there’s probably a reason for keeping these values (I expect they are used in generate), I still find it a bit misleading, considering that forward is supposed to be a single decoding step.