Replicating the “Fine-Tune Whisper For Multilingual ASR with 🤗 Transformers” Tutorial on MPS

Hello! I’m new to Automatic Speech Recognition and I found this incredibly helpful tutorial, Fine-Tune Whisper For Multilingual ASR with 🤗 Transformers.

I’m trying to replicate the tutorial on an M2 Ultra to fine-tune on a language not yet seen by Whisper, but which is linguistically similar to three of the 96 languages it was pretrained on.

However, I cannot get the model to start training, and from what I can tell so far, the problem seems to have something to do with configuring the MPS backend correctly.
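For context, by “configuring the MPS backend” I mean something along the lines of the minimal sketch below (the equivalent model.to(device) lines are currently commented out in my script):

import torch

# Check whether this PyTorch build includes MPS support and whether the backend is usable
print(f"MPS built:     {torch.backends.mps.is_built()}")
print(f"MPS available: {torch.backends.mps.is_available()}")

# Pick the MPS device when it is available, otherwise fall back to the CPU
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

# The model would then be moved onto the selected device before training, e.g.
# model.to(device)  # 'model' being the WhisperForConditionalGeneration instance loaded further down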

Below are the script I’m running, the stack trace, and the details of my system and the Transformers library version I am using:

# Setting up environment

# !pip install --upgrade pip
# !pip install --upgrade datasets[audio] transformers accelerate evaluate jiwer tensorboard gradio

# Importing the BembaSpeech dataset that's been uploaded to a Hugging Face private repo.
# Logging in to Hugging Face

# from huggingface_hub import notebook_login

# notebook_login()

# Importing BembaSpeech dataset that's available locally.
import pandas as pd
from datasets import load_dataset, DatasetDict, Dataset, Audio

csv_file_path = "/Users/user/Desktop/BembaSpeech-master/bem/metadata.csv"
df = pd.read_csv(csv_file_path)

df_dict = df.to_dict(orient='list')
bemba_speech = DatasetDict()

# Casting the audio column from file-path strings to Audio features
bemba_speech["train"] = Dataset.from_dict(df_dict).cast_column("audio", Audio())

# Performing randomised 90:10 train-test split
bemba_speech = bemba_speech['train'].train_test_split(test_size=0.1)

# Verifying the dataset loaded correctly

print(bemba_speech)
print("Dataset Loaded.")

# Setting up the Whisper Feature Extractor

from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-small")
print("Feature Extractor Loaded.")

# Setting up the Whisper Tokenizer

from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small", task="transcribe")
print("Tokenizer Loaded.")

# Verifying that the tokenizer correctly encodes and decodes Bemba text

input_str = bemba_speech["train"][0]["sentence"]
labels = tokenizer(input_str).input_ids
decoded_with_special = tokenizer.decode(labels, skip_special_tokens=False)
decoded_str = tokenizer.decode(labels, skip_special_tokens=True)

print(f"Input:                 {input_str}")
print(f"Decoded w/ special:    {decoded_with_special}")
print(f"Decoded w/out special: {decoded_str}")
print(f"Are equal:             {input_str == decoded_str}")

# Note: Closest Bantu languages to Bemba: Swahili, Shona and Lingala

from transformers import WhisperProcessor

processor = WhisperProcessor.from_pretrained("openai/whisper-small", task="transcribe")
print("Whisper Processor Set Up.")
print(bemba_speech["train"][0])


def prepare_dataset(batch):
    # load the audio data (Whisper's feature extractor expects 16 kHz input)
    audio = batch["audio"]

    # compute log-Mel input features from input audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids
    batch["labels"] = tokenizer(batch["sentence"]).input_ids
    return batch


bemba_speech = bemba_speech.map(prepare_dataset, remove_columns=bemba_speech.column_names["train"], num_proc=1)

from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

model.generation_config.task = "transcribe"

model.generation_config.forced_decoder_ids = None
print("Whisper Model Loaded.")

import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union

# device = torch.device("mps")
# model.to(device)

# print(f"Model is on device: {next(model.parameters()).device}")

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # print(f"input_features type: {type(input_features)}, shape: {[len(f) for f in input_features]}")
        # print(f"label_features type: {type(label_features)}, shape: {[len(f['input_ids']) for f in label_features]}")
        # print(f"Padded batch input_features type: {type(batch['input_features'])}, shape: {batch['input_features'].shape}")
        # print(f"Padded batch labels_batch type: {type(labels_batch['input_ids'])}, shape: {labels_batch['input_ids'].shape}")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # print(f"Labels after masked_fill type: {type(labels)}, shape: {labels.shape}")

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's appended later anyway
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

import evaluate

metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-small-bem",  # change to a repo name of your choice
    per_device_train_batch_size=16,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=5000,
    gradient_checkpointing=True,
    fp16=False,
    evaluation_strategy="steps",
    per_device_eval_batch_size=8,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=True,
)

from transformers import Seq2SeqTrainer

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=bemba_speech["train"],
    eval_dataset=bemba_speech["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)
print("Training Arguments Defined.")

trainer.train()

/Users/user/PycharmProjects/pythonProject/.venv/lib/python3.12/site-packages/transformers/training_args.py:1494: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
  warnings.warn(
Training Arguments Defined.
max_steps is given, it will override any value given in num_train_epochs
  0%|          | 0/5000 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/Users/user/PycharmProjects/pythonProject/.venv/whisperbem.py", line 198, in <module>
    trainer.train()
  File "/Users/user/PycharmProjects/pythonProject/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 1923, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/user/PycharmProjects/pythonProject/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2268, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/user/PycharmProjects/pythonProject/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 3307, in training_step
    loss = self.compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/user/PycharmProjects/pythonProject/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 3338, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/Users/user/PycharmProjects/pythonProject/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/user/PycharmProjects/pythonProject/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/user/PycharmProjects/pythonProject/.venv/lib/python3.12/site-packages/transformers/models/whisper/modeling_whisper.py", line 1753, in forward
    outputs = self.model(
              ^^^^^^^^^^^
  File "/Users/user/PycharmProjects/pythonProject/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/user/PycharmProjects/pythonProject/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/user/PycharmProjects/pythonProject/.venv/lib/python3.12/site-packages/transformers/models/whisper/modeling_whisper.py", line 1611, in forward
    encoder_outputs = self.encoder(
                      ^^^^^^^^^^^^^
  File "/Users/user/PycharmProjects/pythonProject/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/user/PycharmProjects/pythonProject/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/user/PycharmProjects/pythonProject/.venv/lib/python3.12/site-packages/transformers/models/whisper/modeling_whisper.py", line 1174, in forward
    inputs_embeds = nn.functional.gelu(self.conv1(input_features))
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/user/PycharmProjects/pythonProject/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/user/PycharmProjects/pythonProject/.venv/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/user/PycharmProjects/pythonProject/.venv/lib/python3.12/site-packages/torch/nn/modules/conv.py", line 310, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/user/PycharmProjects/pythonProject/.venv/lib/python3.12/site-packages/torch/nn/modules/conv.py", line 306, in _conv_forward
    return F.conv1d(input, weight, bias, self.stride,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Mismatched Tensor types in NNPack convolutionOutput
  0%|          | 0/5000 [00:01<?, ?it/s]


- `transformers` version: 4.42.4
- Platform: macOS-14.5-arm64-arm-64bit
- Python version: 3.12.0
- Huggingface_hub version: 0.23.5
- Safetensors version: 0.4.3
- Accelerate version: 0.32.1
- Accelerate config:    not found
- PyTorch version (GPU?): 2.3.1 (False)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using distributed or parallel set-up in script?: <fill in>

I first tried to run this in Google Colab and got to training with no major issues. The only reason I cannot complete this fine-tuning exercise in Colab is that only the free tier is available in my region, and the session typically times out at around 3 hours.

Thank you so much for your time!