Whisper fine-tuning on Librispeech makes WER worse

I’m trying to fine-tune a Whisper model (whisper-tiny.en, specifically) on a sub-sample of LibriSpeech clean (data corresponding to a random sample of 200 speakers from the train.100 split). The WER for the pre-trained model (without any fine-tuning) is ~5, but during fine-tuning it first jumps to ~16, then slowly decreases to ~9 and never drops below that value.

The fact that the model before any training steps (evaluated using a callback) has such a low WER makes me think the data formatting is fine when it is fed into the model, but that there may be a mismatch in the training script itself. I’m new to audio/ASR, so I'm not sure if I’m doing something wrong here; please excuse any apparent/silly mistakes!

Here’s the relevant part of the code:


from dataclasses import dataclass
from typing import Any, Dict, List, Union

import evaluate
import torch as ch
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, TrainerCallback


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], ch.Tensor]]]) -> Dict[str, ch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]}
                          for feature in features]
        batch = self.processor.feature_extractor.pad(
            input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]}
                          for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(
            label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100)

        # if a bos token was prepended in the previous tokenization step,
        # cut it here, as it is appended again later anyway
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

def tokenize_labels(dataset, tokenizer):
    def prepare_dataset(batch):
        # encode target text to label ids
        batch["labels"] = tokenizer(batch["text"]).input_ids
        return batch

    dataset_ = dataset.map(prepare_dataset,
                           num_proc=8,
                           remove_columns=["file", "speaker_id", "id", "chapter_id"])
    return dataset_

# Tokenize dataset
# I pre-computed and saved features for the model, which are already present
# in the dataset under 'audio'->'array'
train_ds = tokenize_labels(train_ds, tokenizer)
eval_ds = tokenize_labels(eval_ds, tokenizer)

model.freeze_encoder()

gradient_checkpointing = True
training_args = Seq2SeqTrainingArguments(
        output_dir="./testing_training",
        per_device_train_batch_size=train_config.batch_size,
        gradient_accumulation_steps=train_config.gradient_accumulation_steps,
        learning_rate=train_config.learning_rate,
        weight_decay=train_config.weight_decay,
        warmup_steps=500,
        max_steps=train_config.epochs,
        logging_steps=100,
        eval_steps=100,
        evaluation_strategy="steps",
        gradient_checkpointing=gradient_checkpointing,
        fp16=True,
        per_device_eval_batch_size=train_config.batch_size // 2,
        predict_with_generate=True,
        generation_max_length=225,
        save_strategy="no",
        optim="adamw_torch",
        report_to=["tensorboard"],
        load_best_model_at_end=train_config.get_best,
        metric_for_best_model="wer",
        greater_is_better=False,
        push_to_hub=False,
        torch_compile=True,
        dataloader_num_workers=4
)

# Init data collator
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=model.processor)

# Define metrics (WER)
metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id so the labels can be decoded
    label_ids[label_ids == -100] = model.tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = model.tokenizer.batch_decode(pred_ids, skip_special_tokens=True, normalize=True)
    label_str = model.tokenizer.batch_decode(label_ids, skip_special_tokens=True, normalize=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

# Initialize trainer
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=model.processor.feature_extractor,
)
    
# Callback to evaluate at the first step
# Useful to log the metrics of the base model before any fine-tuning
class EvaluateFirstStepCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step == 1:
            control.should_evaluate = True

trainer.add_callback(EvaluateFirstStepCallback())

# Fine-tune model
trainer.train()

# Get metrics after fine-tuning
eval_results = trainer.evaluate(eval_ds)
loss = eval_results["eval_loss"]
wer = eval_results["eval_wer"]

Here is what the WER and loss curves look like:

This code is mostly based on @sanchit-gandhi’s blog post on fine-tuning Whisper models. While the tutorial itself did not handle text normalization (since it wasn’t needed for the example dataset), I did note in some other places (and on manual inspection) that there were mismatches in upper/lower case etc., which is why I added the normalization step when decoding.
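To illustrate the kind of mismatch I mean (made-up strings, not actual model output): the decoded hypotheses (with normalize=True) come out lower-case without punctuation, while the raw LibriSpeech transcripts are all upper-case, and the WER metric itself is case-sensitive:

import evaluate

metric = evaluate.load("wer")
prediction = "mister quilter is the apostle of the middle classes"
reference = "MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES"

print(metric.compute(predictions=[prediction], references=[reference]))          # 1.0 -> every word counts as an error
print(metric.compute(predictions=[prediction], references=[reference.lower()]))  # 0.0 -> exact match after lower-casing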


Hi Anshuman,

I encountered a similar problem before. For my project, applying a lowercasing mapping function during pre-processing solved the issue. Could you try to add this step as well?
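Something along these lines, as a minimal sketch (assuming your transcripts live in a "text" column, as in LibriSpeech):

def lowercase_text(batch):
    # lower-case the ground-truth transcript before the labels are tokenized
    batch["text"] = batch["text"].lower()
    return batch

# apply to both splits, before calling tokenize_labels
train_ds = train_ds.map(lowercase_text)
eval_ds = eval_ds.map(lowercase_text)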

Regards,
Tony

I am experiencing the exact same issue with the FLEURS English dataset and the Whisper small model: the evaluation loss decreases over time, but the WER increases and then levels off.

I used the normalized transcription column from the Hugging Face FLEURS dataset, which already takes care of lower-casing the text and removing punctuation.
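Concretely, that is the "transcription" field (as opposed to "raw_transcription"), if I remember the column names correctly:

from datasets import load_dataset

# "transcription" is the normalized text; "raw_transcription" keeps the original casing/punctuation
fleurs = load_dataset("google/fleurs", "en_us", split="train")
print(fleurs[0]["transcription"])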

My inclination is that the issue is related to the hyperparameters, but I have not yet resolved this. My hyperparameters are very similar to yours except I used a warmup-ratio (0.1) instead of warmup steps.

@tonywu71 If you fine-tuned an English whisper model, would you be able to share your choice of hyperparameters?

I’m afraid I only fine-tuned the multilingual version of Whisper. I usually go with a warmup ratio of 1% and a learning rate of 1e-5.
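In terms of Seq2SeqTrainingArguments that would roughly be the following (output_dir is just a placeholder, everything else left as you already have it):

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-finetune",  # placeholder
    learning_rate=1e-5,
    warmup_ratio=0.01,                # 1% of the total training steps
    # ... keep the rest of your existing arguments
)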


@tonywu71 yes, adding that normalization step indeed fixed it. Thanks!

I’m glad you were able to fix this! Can you share the code snippet that you added?

The main issue was that Whisper is trained to predict lower-case text while LibriSpeech transcripts are all upper-case. Since the tokenizer has different tokens for upper and lower case, this disparity just makes the model learn not to use lower-case tokens, which is unnecessary (hence the high WERs).
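You can check this directly with the tokenizer (quick sanity check for whisper-tiny.en):

from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-tiny.en")

# upper- and lower-case text map to completely different token ids
print(tokenizer("HELLO WORLD").input_ids)
print(tokenizer("hello world").input_ids)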

The only change I made was adding a .lower() to the ground-truth labels before computing the WER. The normalization step then really only removes punctuation.
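In terms of the compute_metrics above, it boils down to something like this, right before the metric.compute call:

# inside compute_metrics
label_str = [s.lower() for s in label_str]
wer = 100 * metric.compute(predictions=pred_str, references=label_str)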
