I’m trying to fine-tune a Whisper model (whisper-tiny.en, specifically) on a sub-sample of LibriSpeech clean (data corresponding to a random sample of 200 speakers from the train.100 split). The WER for the pre-trained model (without any fine-tuning) is ~5%, but during fine-tuning it first jumps to ~16% and then slowly decreases to ~9%, never dropping below that value.
The fact that the model has such a low WER before any training steps (evaluated using a callback) makes me think the data formatting is fine when it is fed into the model, and that there may instead be a mismatch in the training script itself. I’m new to audio/ASR, so I’m not sure if I’m doing something wrong here; please excuse any apparent/silly mistakes!
Here’s a small code snippet:
from dataclasses import dataclass
from typing import Any, Dict, List, Union

import evaluate
import torch as ch
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, TrainerCallback


@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], ch.Tensor]]]) -> Dict[str, ch.Tensor]:
        # split inputs and labels since they have to be of different lengths
        # and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]}
                          for feature in features]
        batch = self.processor.feature_extractor.pad(
            input_features, return_tensors="pt")
        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]}
                          for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(
            label_features, return_tensors="pt")
        # replace padding with -100 so it is ignored by the loss
        labels = labels_batch["input_ids"].masked_fill(
            labels_batch.attention_mask.ne(1), -100)
        # if a bos token was appended in the previous tokenization step,
        # cut it here since it gets appended again later anyway
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]
        batch["labels"] = labels
        return batch
def tokenize_labels(dataset, tokenizer):
    def prepare_dataset(batch):
        # encode target text to label ids
        batch["labels"] = tokenizer(batch["text"]).input_ids
        return batch

    dataset_ = dataset.map(prepare_dataset,
                           num_proc=8,
                           remove_columns=["file", "speaker_id", "id", "chapter_id"])
    return dataset_
# Tokenize dataset
# I pre-computed and saved the input features for the model, which are already
# present in the dataset under 'audio'->'array' (the pre-compute step is
# sketched below for reference)
train_ds = tokenize_labels(train_ds, tokenizer)
eval_ds = tokenize_labels(eval_ds, tokenizer)
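# For reference, a rough sketch of the (separate) pre-compute step; this is an
# illustration rather than the exact code I used. It assumes the raw waveforms
# live under 'audio'->'array'/'sampling_rate' and stores the log-mel features
# under 'input_features', which is what the collator above reads.
def precompute_features(batch):  # hypothetical helper, not part of the training script
    audio = batch["audio"]
    batch["input_features"] = model.processor.feature_extractor(
        audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]
    return batch
# train_ds = train_ds.map(precompute_features)
# eval_ds = eval_ds.map(precompute_features)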
model.freeze_encoder()
gradient_checkpointing = True
training_args = Seq2SeqTrainingArguments(
    output_dir="./testing_training",
    per_device_train_batch_size=train_config.batch_size,
    gradient_accumulation_steps=train_config.gradient_accumulation_steps,
    learning_rate=train_config.learning_rate,
    weight_decay=train_config.weight_decay,
    warmup_steps=500,
    max_steps=train_config.epochs,
    logging_steps=100,
    eval_steps=100,
    evaluation_strategy="steps",
    gradient_checkpointing=gradient_checkpointing,
    fp16=True,
    per_device_eval_batch_size=train_config.batch_size // 2,
    predict_with_generate=True,
    generation_max_length=225,
    save_strategy="no",
    optim="adamw_torch",
    report_to=["tensorboard"],
    load_best_model_at_end=train_config.get_best,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
    torch_compile=True,
    dataloader_num_workers=4
)
# Init data collator
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=model.processor)
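# Optional sanity check (not part of training): collate a couple of examples and
# decode the labels back to text, to confirm the batch format is what the model
# expects. This assumes the dataset rows already contain 'input_features' and
# 'labels', as used by the collator above.
sample_batch = data_collator([train_ds[i] for i in range(2)])
print(sample_batch["input_features"].shape)  # should be (2, 80, 3000) for whisper-tiny.en
print(model.processor.tokenizer.batch_decode(
    sample_batch["labels"].masked_fill(sample_batch["labels"] == -100,
                                       model.processor.tokenizer.pad_token_id),
    skip_special_tokens=True))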
# Define metrics (WER)
metric = evaluate.load("wer")

def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids
    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = model.tokenizer.pad_token_id
    # we do not want to group tokens when computing the metrics
    pred_str = model.tokenizer.batch_decode(pred_ids, skip_special_tokens=True, normalize=True)
    label_str = model.tokenizer.batch_decode(label_ids, skip_special_tokens=True, normalize=True)
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}
# Initialize trainer
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)
# Callback to evaluate at first step
# Useful to log metric of base model
class EvaluateFirstStepCallback(TrainerCallback):
    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step == 1:
            control.should_evaluate = True

trainer.add_callback(EvaluateFirstStepCallback())
# Fine-tune model
trainer.train()
# Get metrics after training
eval_results = trainer.evaluate(eval_ds)
loss = eval_results["eval_loss"]
wer = eval_results["eval_wer"]
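# e.g., to log the final numbers
print(f"Final eval loss: {loss:.3f} | WER: {wer:.2f}")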
This is what the WER and loss curves look like:
This code is mostly based on @sanchit-gandhi’s blog post on fine-tuning Whisper models. While the tutorial itself did not handle text normalization (since it wasn’t needed for the example dataset), I noticed in some other places (and on manual inspection) that there were mismatches in upper/lower case, etc., which is why I added normalization.
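For completeness, here is a minimal sketch of applying that normalization explicitly via the BasicTextNormalizer shipped with transformers, instead of relying on normalize=True in batch_decode (this is just an illustration; compute_metrics_normalized is a hypothetical name, and the basic normalizer is not necessarily identical to the English normalizer the tokenizer applies, so results may differ slightly):

from transformers.models.whisper.english_normalizer import BasicTextNormalizer

normalizer = BasicTextNormalizer()

def compute_metrics_normalized(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids
    label_ids[label_ids == -100] = model.tokenizer.pad_token_id
    # decode without tokenizer-side normalization, then normalize explicitly
    pred_str = [normalizer(s) for s in model.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)]
    label_str = [normalizer(s) for s in model.tokenizer.batch_decode(label_ids, skip_special_tokens=True)]
    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"wer": wer}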