To whom it may concern,
First, here is some context. As part of a research project, I want to distill the Whisper model on a dataset of non-native English speakers. By distilling Whisper, the goal is to create a more streamlined and cost-efficient model for speech recognition while maintaining high performance. Because I am using a custom loss, because `generate` does not propagate gradients, and because I need more control over teacher forcing, I decided to use the `forward` method of Whisper.
After looking at the `forward` documentation for WhisperDecoderLayer, it seems that setting `model.config.forced_decoder_ids` for the task, language, and BOS tokens doesn’t affect the `forward` method. For the moment, I decided to add these tokens manually using the code in Snippet 2.
To test whether I could match the behavior of `generate` with `forward`, I ran a teacher-forced greedy search through `forward`, using the output of `generate` as the decoder input. If the two outputs don’t match, it means that my implementation of greedy search differs from the one in `generate`.
According to Snippet 3, the two outputs don’t match. Therefore, my reasoning must be wrong.
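As an illustration of the equivalence this check relies on, here is a minimal sketch on a toy causal decoder rather than on Whisper. `ToyDecoder` and `greedy_generate` are hypothetical names introduced only for this example, and the sketch assumes the usual next-token convention: the logits at position t predict the token at position t + 1. Under that assumption, feeding the greedily generated sequence without its last token should reproduce the generated sequence without its first token.

```python
import torch
from torch import nn


class ToyDecoder(nn.Module):
    """Hypothetical stand-in for a causal decoder (this is NOT Whisper)."""

    def __init__(self, vocab_size: int = 16, hidden: int = 8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.rnn = nn.GRU(hidden, hidden, batch_first=True)
        self.head = nn.Linear(hidden, vocab_size)

    def forward(self, decoder_input_ids: torch.Tensor) -> torch.Tensor:
        # logits[:, t] only depends on decoder_input_ids[:, : t + 1]
        hidden_states, _ = self.rnn(self.embed(decoder_input_ids))
        return self.head(hidden_states)  # (batch, seq_len, vocab_size)


@torch.no_grad()
def greedy_generate(model: nn.Module, bos_id: int, max_new_tokens: int) -> torch.Tensor:
    """Plain greedy search: repeatedly append the argmax of the last position."""
    ids = torch.tensor([[bos_id]])
    for _ in range(max_new_tokens):
        next_id = model(ids)[:, -1].argmax(dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=1)
    return ids


torch.manual_seed(0)
model = ToyDecoder().double().eval()
generated = greedy_generate(model, bos_id=0, max_new_tokens=10)

# Teacher forcing: the input is the generated sequence without its last token.
with torch.no_grad():
    logits = model(generated[:, :-1])
teacher_forced = logits.argmax(dim=-1)

# The prediction at position t is compared with the generated token at t + 1.
matches = torch.equal(teacher_forced, generated[:, 1:])
print(matches)
```

The point of the sketch is the one-position shift between inputs and predictions: the teacher-forced output is compared against `generated[:, 1:]`, not against `generated` itself.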
My questions are thus:
- Is my understanding of teacher-forced predictions correct with respect to Snippet 1?
- If yes, how should I change my code to match the output from `generate`?
Thank you very much in advance for your time.
Yours sincerely,
Tony
Code snippet 1:
model.forward(input_features=input_features,
              decoder_input_ids=labels_with_prompt)
Code snippet 2:
from typing import Tuple

import torch
from transformers import WhisperTokenizer

BOS_TOKEN_ID = 50258


def get_labels_with_prompt(labels: torch.Tensor,
                           tokenizer: WhisperTokenizer,
                           language: str = "en",
                           task: str = "transcribe",
                           no_timestamps: bool = True) -> Tuple[torch.Tensor, int, int]:
    """
    Returns the labels with the prefix and suffix tokens, as well as the number of prefix and suffix tokens.
    `labels_with_prompt` should be used as the `decoder_input_ids` argument for the `forward` method of the model.
    Note: n_prefix_tokens should be 4 (BOS, language, task, if_timestamps) and n_suffix_tokens should be 1 (EOS).
    """
    # Get batch size:
    batch_size = labels.shape[0]

    # Get prefix tokens:
    forced_decoder_ids = tokenizer.get_decoder_prompt_ids(language=language, task=task, no_timestamps=no_timestamps)  # language, task, if_timestamps
    prefix_tokens = torch.IntTensor([BOS_TOKEN_ID] + [token_id for idx, token_id in forced_decoder_ids])  # (n_prefix_tokens, )
    prefix_tokens = prefix_tokens.expand(batch_size, -1)  # (batch_size, n_prefix_tokens)

    # Get suffix tokens:
    suffix_tokens = torch.IntTensor([tokenizer.eos_token_id])  # (n_suffix_tokens, )
    suffix_tokens = suffix_tokens.expand(batch_size, -1)  # (batch_size, n_suffix_tokens)

    # Get prefix and suffix lengths:
    n_prefix_tokens = prefix_tokens.shape[1]  # n_prefix_tokens
    n_suffix_tokens = suffix_tokens.shape[1]  # n_suffix_tokens

    # Send tensors to the same device as the `labels` tensor:
    prefix_tokens = prefix_tokens.to(labels.device)
    suffix_tokens = suffix_tokens.to(labels.device)

    # Concatenate the prefix tensor with the original tensor along the second dimension:
    labels_with_prompt = torch.cat((prefix_tokens, labels, suffix_tokens), dim=1)  # (batch_size, n_tokens_labels + n_prefix_tokens + n_suffix_tokens)

    return labels_with_prompt, n_prefix_tokens, n_suffix_tokens
Code snippet 3:
#!/usr/bin/env python
# coding: utf-8
import torch
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from datasets import load_dataset
from trainer.prompting import get_labels_with_prompt
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe") # type: ignore
model.config.suppress_tokens = []
normalizer = processor.tokenizer._normalize
## Load dataset
# load dummy dataset and read audio files
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
x = next(iter(ds))
label = normalizer(x["text"]) # normalize label
input_features = processor(x["audio"]["array"], sampling_rate=x["audio"]["sampling_rate"], return_tensors="pt").input_features
## Tokenize the labels for teacher-forcing
tokenized_label = torch.LongTensor(processor.tokenizer(label, add_special_tokens=False).input_ids)
# Add batch dim:
tokenized_labels = tokenized_label[None, :]
## Add prompts to teacher-forced labels
processor.tokenizer.get_decoder_prompt_ids(language=None, task=None)
labels_with_prompt, n_prefix_tokens_labels, n_suffix_tokens_labels = get_labels_with_prompt(
    labels=tokenized_labels, language="english", task="transcribe", tokenizer=processor.tokenizer)
processor.tokenizer.batch_decode(labels_with_prompt, skip_special_tokens=False, normalize=False)
## Predict
## Teacher-forced from greedy search
# Generate with greedy search - vanilla
pred_gen_raw = model.generate(inputs=input_features)
pred_gen_str = processor.tokenizer.batch_decode(pred_gen_raw, skip_special_tokens=True, normalize=True)
pred_gen = torch.LongTensor(processor.tokenizer.encode(pred_gen_str[0], add_special_tokens=False))[None, :]
processor.tokenizer.batch_decode(pred_gen, skip_special_tokens=False, normalize=False)
pred_gen_with_prompts, n_prefix_tokens_labels, n_suffix_tokens_labels = get_labels_with_prompt(
    labels=pred_gen, language=None, task=None, tokenizer=processor.tokenizer)
output = model.forward(input_features=input_features,
                       decoder_input_ids=pred_gen_with_prompts)
logits = output.logits
pred_ids = torch.argmax(logits, dim=-1)
processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=False, normalize=False)
## With `generate`
# Generate with greedy search - vanilla
pred_gen = model.generate(inputs=input_features)
processor.tokenizer.batch_decode(pred_gen, skip_special_tokens=False, normalize=False)
## Comparison
print(processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True, normalize=False))
print(processor.tokenizer.batch_decode(pred_gen, skip_special_tokens=True, normalize=False))
The last two lines print:
>> [' Mrister Quter is theive apostle of thethe middle classes and we are glad toto welcome his gospel.']
>> [' Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel.']
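To compare the two outputs at the token level rather than as decoded strings, a small helper such as the one below could be used. This is only a sketch: `first_mismatch` is a hypothetical name, the token ids in the usage lines are made up, and it assumes that the prediction at position t of the teacher-forced pass corresponds to the decoder input at position t + 1.

```python
import torch


def first_mismatch(teacher_forced_ids: torch.Tensor, decoder_input_ids: torch.Tensor) -> int:
    """Return the first position where the teacher-forced prediction differs
    from the next decoder input token, or -1 if all positions agree.

    Both tensors have shape (batch_size, seq_len).
    """
    # The prediction at position t is compared with the input at position t + 1,
    # so the last prediction and the first input token are dropped.
    predicted = teacher_forced_ids[:, :-1]
    expected = decoder_input_ids[:, 1:]
    differs = (predicted != expected).any(dim=0)
    positions = torch.nonzero(differs).flatten()
    return int(positions[0]) if positions.numel() > 0 else -1


# Made-up token ids, for illustration only.
decoder_inputs = torch.tensor([[0, 5, 6, 7, 1]])
aligned_preds = torch.tensor([[5, 6, 7, 1, 9]])
shifted_error = torch.tensor([[5, 6, 8, 1, 9]])

print(first_mismatch(aligned_preds, decoder_inputs))
print(first_mismatch(shifted_error, decoder_inputs))
```

Comparing ids instead of strings also avoids differences introduced by normalization or re-tokenization of decoded text.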
System Info
- transformers version: 4.30.2
- Platform: macOS-13.4-arm64-arm-64bit
- Python version: 3.10.11
- Huggingface_hub version: 0.15.1
- Safetensors version: 0.3.1
- PyTorch version (GPU?): 2.0.0 (False)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?:
- Using distributed or parallel set-up in script?: