Disparity between output from `forward` and `generate` for greedy search (using Whisper)

To whom it may concern,

First, here is some context. As part of a research project, I want to distill the Whisper model on a dataset of non-native English speakers. By distilling Whisper, the goal is to create a more streamlined and cost-efficient model for speech recognition while maintaining high performance. Because I am using a custom loss, because `generate` doesn't propagate gradients, and because I need more control over teacher forcing, I decided to use the `forward` method of Whisper.

After taking a look at the `forward` documentation for `WhisperDecoderLayer`, it seems that setting `model.config.forced_decoder_ids` for the task, language, and BOS tokens doesn't affect the `forward` method. For the moment, I decided to add these tokens manually using the code in Snippet 2.
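
For reference, this is the kind of prefix I am reproducing manually. The token ids in the comments are what I expect for the multilingual `openai/whisper-tiny` checkpoint (illustrative, not guaranteed for other checkpoints):

from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

# (position, token_id) pairs that `generate` would normally force:
print(tokenizer.get_decoder_prompt_ids(language="en", task="transcribe"))
# Expected for this checkpoint: [(1, 50259), (2, 50359), (3, 50363)]
#                               i.e. <|en|>, <|transcribe|>, <|notimestamps|>
# Position 0 is reserved for BOS (<|startoftranscript|>, id 50258), which is
# why I prepend it manually in Snippet 2.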

To test whether I could match the behavior of `forward` with `generate`, I ran greedy search with `generate` and then used its output as teacher-forcing labels for `forward`. If the two outputs don't match, we know that my teacher-forced setup doesn't reproduce the greedy search from `generate`.
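
To make precise what "match" means here, below is my mental model of the teacher-forced call from Snippet 1, using the variable names from the snippets further down (question 1 asks whether this mental model is correct):

# One forward pass in which the decoder sees the ground-truth prefix at every
# position, instead of its own previous predictions:
output = model(input_features=input_features, decoder_input_ids=labels_with_prompt)

# output.logits has shape (batch_size, sequence_length, vocab_size).
# As I understand it, logits[:, i] is the distribution over the token at
# position i + 1, so if `labels_with_prompt` came from greedy `generate`,
# output.logits[:, :-1].argmax(dim=-1) should equal labels_with_prompt[:, 1:].
pred_ids = output.logits.argmax(dim=-1)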

According to my results in Snippet 3, the outputs don't match. Therefore, my reasoning must be wrong somewhere.

My questions are thus:

  1. Is my understanding of teacher-forced predictions correct with respect to Snippet 1?
  2. If yes, how should I change my code to match the output from `generate`?

Thank you very much in advance for your time.

Yours sincerely,

Tony

Snippet 1:

model.forward(input_features=input_features,
              decoder_input_ids=labels_with_prompt)

Snippet 2:

from typing import Tuple
import torch
from transformers import WhisperTokenizer


BOS_TOKEN_ID = 50258  # <|startoftranscript|> for multilingual Whisper checkpoints


def get_labels_with_prompt(labels: torch.Tensor,
                           tokenizer: WhisperTokenizer,
                           language: str = "en",
                           task: str = "transcribe",
                           no_timestamps: bool = True) -> Tuple[torch.Tensor, int, int]:
    """
    Returns the labels with the prefix and suffix tokens, as well as the number of prefix and suffix tokens.
    `labels_with_prompt` should be used as the `decoder_input_ids` argument for the `forward` method of the model.
    
    Note: with language and task set, n_prefix_tokens should be 4 (BOS, language, task, no_timestamps) and n_suffix_tokens should be 1 (EOS).
    """
    
    # Get batch size:
    batch_size = labels.shape[0]

    # Get prefix tokens:
    forced_decoder_ids = tokenizer.get_decoder_prompt_ids(language=language, task=task, no_timestamps=no_timestamps)  # language, task, no_timestamps
    prefix_tokens = torch.LongTensor([BOS_TOKEN_ID] + [token_id for idx, token_id in forced_decoder_ids])  # (n_prefix_tokens, )
    prefix_tokens = prefix_tokens.expand(batch_size, -1)  # (batch_size, n_prefix_tokens)

    # Get suffix tokens:
    suffix_tokens = torch.LongTensor([tokenizer.eos_token_id])  # (n_suffix_tokens, )
    suffix_tokens = suffix_tokens.expand(batch_size, -1)  # (batch_size, n_suffix_tokens)

    # Get prefix and suffix lengths:
    n_prefix_tokens = prefix_tokens.shape[1]  # n_prefix_tokens
    n_suffix_tokens = suffix_tokens.shape[1]  # n_suffix_tokens
    
    # Send tensors to the same device as the `labels` tensor:
    prefix_tokens = prefix_tokens.to(labels.device)
    suffix_tokens = suffix_tokens.to(labels.device)
    
    # Concatenate the prefix tensor with the original tensor along the second dimension:
    labels_with_prompt = torch.cat((prefix_tokens, labels, suffix_tokens), dim=1)  # (batch_size, n_tokens_labels + n_prefix_tokens + n_suffix_tokens)

    return labels_with_prompt, n_prefix_tokens, n_suffix_tokens
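
For completeness, here is a small usage example of this helper. The decoded string in the comment is what I expect for `openai/whisper-tiny`, assuming the standard multilingual special tokens:

import torch
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny")

labels = torch.LongTensor(tokenizer(" hello world", add_special_tokens=False).input_ids)[None, :]
labels_with_prompt, n_prefix, n_suffix = get_labels_with_prompt(labels, tokenizer)

print(n_prefix, n_suffix)  # expected: 4 1
print(tokenizer.batch_decode(labels_with_prompt, skip_special_tokens=False))
# Expected (if my prefix logic is right):
# ['<|startoftranscript|><|en|><|transcribe|><|notimestamps|> hello world<|endoftext|>']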
 

Snippet 3:


import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
from trainer.prompting import get_labels_with_prompt  # the helper from Snippet 2


model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")

model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")  # type: ignore
model.config.suppress_tokens = []

normalizer = processor.tokenizer._normalize  # private helper, used only to normalize text


## Load dataset
# load dummy dataset and read audio files
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
x = next(iter(ds))

label = normalizer(x["text"])  # normalize label
input_features = processor(x["audio"]["array"], sampling_rate=x["audio"]["sampling_rate"], return_tensors="pt").input_features

## Tokenize the labels for teacher-forcing
tokenized_label = torch.LongTensor(processor.tokenizer(label, add_special_tokens=False).input_ids)

# Add batch dim:
tokenized_labels = tokenized_label[None, :]


## Add prompts to teacher-forced labels
print(processor.tokenizer.get_decoder_prompt_ids(language=None, task=None))  # sanity check: forced ids without language/task

labels_with_prompt, n_prefix_tokens_labels, n_suffix_tokens_labels = get_labels_with_prompt(
    labels=tokenized_labels, language="english", task="transcribe", tokenizer=processor.tokenizer)

# Inspect the full teacher-forced sequence, special tokens included:
print(processor.tokenizer.batch_decode(labels_with_prompt, skip_special_tokens=False, normalize=False))


## Predict
## Teacher-forced from greedy search

# Generate with greedy search - vanilla
pred_gen_raw = model.generate(inputs=input_features)

# Decode, normalize, and re-encode so the teacher-forced labels go through the
# same normalization as the ground-truth labels:
pred_gen_str = processor.tokenizer.batch_decode(pred_gen_raw, skip_special_tokens=True, normalize=True)
pred_gen = torch.LongTensor(processor.tokenizer.encode(pred_gen_str[0], add_special_tokens=False))[None, :]

print(processor.tokenizer.batch_decode(pred_gen, skip_special_tokens=False, normalize=False))

pred_gen_with_prompts, n_prefix_tokens_labels, n_suffix_tokens_labels = get_labels_with_prompt(
    labels=pred_gen, language=None, task=None, tokenizer=processor.tokenizer)

# Teacher-forced forward pass, then greedy pick at every position:
output = model.forward(input_features=input_features,
                       decoder_input_ids=pred_gen_with_prompts)
logits = output.logits
pred_ids = torch.argmax(logits, dim=-1)

print(processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=False, normalize=False))


## With `generate`

# Generate with greedy search - vanilla (same call as above, repeated for clarity)
pred_gen = model.generate(inputs=input_features)
print(processor.tokenizer.batch_decode(pred_gen, skip_special_tokens=False, normalize=False))


## Comparison
print(processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True, normalize=False))
print(processor.tokenizer.batch_decode(pred_gen, skip_special_tokens=True, normalize=False))

The last two print statements output:

>> [' Mrister Quter is theive apostle of thethe middle classes and we are glad toto welcome his gospel.']
>> [' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.']
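
In case it helps, below is the kind of token-level check I would use to localize the first position where the two sequences diverge (a quick sketch reusing `pred_ids` and `pred_gen` from Snippet 3; note that `pred_ids` also contains predictions at the prompt positions):

# Hypothetical debugging helper: find the first token id where the
# teacher-forced argmax predictions and the `generate` output disagree.
min_len = min(pred_ids.shape[1], pred_gen.shape[1])
for i in range(min_len):
    if pred_ids[0, i] != pred_gen[0, i]:
        print(f"first divergence at position {i}: "
              f"{processor.tokenizer.decode([pred_ids[0, i].item()])!r} vs "
              f"{processor.tokenizer.decode([pred_gen[0, i].item()])!r}")
        break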

System Info

  • transformers version: 4.30.2
  • Platform: macOS-13.4-arm64-arm-64bit
  • Python version: 3.10.11
  • Huggingface_hub version: 0.15.1
  • Safetensors version: 0.3.1
  • PyTorch version (GPU?): 2.0.0 (False)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?: