Need help training Speech2Text from scratch

I am currently trying to train a Speech2Text model from scratch but what I am seeing during training is odd…

For some reason the word error rate (WER) is already quite good — 50% or less after the very first step — which simply cannot be right for a model trained from scratch.

The WER then increases as the model progressively returns more and more garbage.

Clearly, there is something wrong with the way I am training the model.

Can somebody help me find out what I’m doing wrong here or point me to an example that explains how to do this right?

This is how I am initializing the Seq2SeqTrainer and the Speech2TextForConditionalGeneration:

# Trainer hyper-parameters: evaluate and checkpoint on the same step
# interval so load_best_model_at_end can restore the best checkpoint
# (selected by eval_loss, the default when metric_for_best_model is unset).
# NOTE(review): predict_with_generate=True is almost certainly wanted here.
# Without it, Seq2SeqTrainer evaluates with a teacher-forced forward pass,
# so the "predictions" the metrics see are logits conditioned on the gold
# labels — which is why WER looks artificially good at step 0 and then
# degrades.  TODO confirm and enable (compute_metrics must then handle
# generated token ids instead of logits).
training_args = Seq2SeqTrainingArguments(
    output_dir=out_dir,
    evaluation_strategy=IntervalStrategy("steps"),
    save_steps=save_and_eval_steps,
    eval_steps=save_and_eval_steps,  # kept in sync with save_steps on purpose
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir=log_dir,
    group_by_length=True,  # batch similar-length samples to reduce padding
    load_best_model_at_end=True,
    save_total_limit=2,  # keep at most 2 checkpoints on disk
)

# Build a fresh (randomly initialised) Speech2Text model whose special-token
# ids mirror the tokenizer that produced the training targets.
tok = processor.tokenizer
config = Speech2TextConfig(
    return_dict=True,
    sampling_rate=sampling_rate,
    vocab_size=tokenizer.vocab_size,
    pad_token_id=tok.pad_token_id,
    bos_token_id=tok.bos_token_id,
    eos_token_id=tok.eos_token_id,
    # NOTE(review): Speech2Text checkpoints conventionally start decoding
    # from eos; using bos here is a deliberate-looking choice — confirm.
    decoder_start_token_id=tok.bos_token_id,
)

model = Speech2TextForConditionalGeneration(config)
model.train()

# Wire model, data pipeline, metrics and early stopping into the trainer.
collator = Speech2TextCollator(processor=processor)
metrics_fn = partial(compute_metrics, processor)
stopper = EarlyStoppingCallback(early_stopping_patience=early_stopping_patience)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collator,
    compute_metrics=metrics_fn,
    callbacks=[stopper],
)

# get_last_checkpoint returns None when out_dir holds no checkpoint yet,
# in which case training starts from scratch.
trainer.train(resume_from_checkpoint=trainer_utils.get_last_checkpoint(out_dir))

The other important thing is probably what the input to the model is. So, the data collator does the following:

class Speech2TextCollator:
    """Collate variable-length speech features and token targets into a batch.

    Pads ``inputs`` (float feature frames), ``targets`` (token ids) and
    ``attention_mask`` to the longest element in the batch and returns the
    keyword arguments expected by ``Speech2TextForConditionalGeneration``.

    The redundant ``@dataclass`` decorator was removed: with no class-level
    fields it generated nothing useful next to the hand-written ``__init__``.
    """

    def __init__(self, processor: "Speech2TextProcessor"):
        # Stored for interface parity; __call__ itself does not use it.
        self.processor = processor

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # torch.as_tensor is a no-op on tensors and converts plain lists;
        # dtype pinned to float32 to match the original torch.Tensor(...) call.
        inputs = [torch.as_tensor(f["inputs"], dtype=torch.float32) for f in features]
        targets = [torch.as_tensor(f["targets"]) for f in features]
        # Bug fix: attention_mask entries may be plain lists (see the type
        # hint) but were passed to pad_sequence unconverted, which requires
        # tensors — convert them like inputs/targets.
        masks = [torch.as_tensor(f["attention_mask"]) for f in features]

        inputs_batch = pad_sequence(inputs, batch_first=True)
        targets_batch = pad_sequence(targets, batch_first=True).long()
        attention_mask = pad_sequence(masks, batch_first=True).long()
        # NOTE(review): labels are padded with 0, which the loss counts as
        # real tokens; padding with -100 (and mapping -100 back to the pad id
        # before decoding in the metrics) is the usual fix — TODO confirm and
        # apply together with the metrics change.
        return dict(
            input_features=inputs_batch,
            attention_mask=attention_mask,
            labels=targets_batch,
        )

Less relevant is how I am computing the WER:

def compute_metrics(processor: "Speech2TextProcessor", pred):
    """Compute word- and character-error-rate for an evaluation step.

    Handles both evaluation modes:
    - ``predict_with_generate=True``: ``pred.predictions`` are generated
      token ids of shape (batch, seq_len);
    - teacher-forced forward pass (the default): predictions are logits of
      shape (batch, seq_len, vocab), possibly wrapped in a tuple.

    Bug fixes vs. the original: ``pred.predictions[0]`` indexed the first
    *sample* of a plain logits array (so only one sequence, decoded
    token-by-token, was ever scored); and label ids containing the -100
    loss mask would have crashed decoding.
    """
    preds = pred.predictions
    if isinstance(preds, tuple):
        # Trainer returns (logits, ...) when not generating — keep logits.
        preds = preds[0]
    preds = np.asarray(preds)
    if preds.ndim == 3:
        # Logits from teacher forcing: greedy argmax over the vocab axis.
        # NOTE(review): teacher forcing conditions the decoder on the gold
        # labels, so this WER looks artificially good early in training;
        # prefer predict_with_generate=True for a meaningful WER.
        preds = np.argmax(preds, axis=-1)
    pred_str = processor.batch_decode(preds)

    # Map the -100 loss mask (if present) back to the pad token so the
    # tokenizer can decode the references; a no-op when labels have no -100.
    label_ids = np.asarray(pred.label_ids)
    label_ids = np.where(label_ids == -100, processor.tokenizer.pad_token_id, label_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(label_ids, group_tokens=False)

    wer = error_rate(targets=label_str, predictions=pred_str, tokens="words")
    cer = error_rate(targets=label_str, predictions=pred_str, tokens="characters")

    return {"wer": wer, "cer": cer}

The output during the training looks something like this:

{'eval_loss': 6.661579608917236, 'eval_wer': 0.5277777777777778, 'eval_cer': 0.37888198757763975, 'eval_runtime': 0.1115, 'eval_samples_per_second': 26.916, 'eval_steps_per_second': 8.972, 'epoch': 0.11}
                                                                                
{'eval_loss': 6.655572414398193, 'eval_wer': 0.5833333333333334, 'eval_cer': 0.40993788819875776, 'eval_runtime': 0.0951, 'eval_samples_per_second': 31.544, 'eval_steps_per_second': 10.515, 'epoch': 0.11}

{'eval_loss': 6.649582386016846, 'eval_wer': 0.6111111111111112, 'eval_cer': 0.453416149068323, 'eval_runtime': 0.0953, 'eval_samples_per_second': 31.472, 'eval_steps_per_second': 10.491, 'epoch': 0.12}        

{'eval_loss': 6.643609523773193, 'eval_wer': 0.6111111111111112, 'eval_cer': 0.453416149068323, 'eval_runtime': 0.1001, 'eval_samples_per_second': 29.982, 'eval_steps_per_second': 9.994, 'epoch': 0.12}
                                                                              
{'eval_loss': 6.637502193450928, 'eval_wer': 0.6388888888888888, 'eval_cer': 0.4658385093167702, 'eval_runtime': 0.1108, 'eval_samples_per_second': 27.085, 'eval_steps_per_second': 9.028, 'epoch': 0.12}
                                                                             
{'eval_loss': 6.631495952606201, 'eval_wer': 0.6388888888888888, 'eval_cer': 0.4658385093167702, 'eval_runtime': 0.1021, 'eval_samples_per_second': 29.372, 'eval_steps_per_second': 9.791, 'epoch': 0.12}
                                                                             
{'eval_loss': 6.6251702308654785, 'eval_wer': 0.6944444444444444, 'eval_cer': 0.4906832298136646, 'eval_runtime': 0.1079, 'eval_samples_per_second': 27.795, 'eval_steps_per_second': 9.265, 'epoch': 0.12}
                                                                       
{'eval_loss': 6.618703365325928, 'eval_wer': 0.6944444444444444, 'eval_cer': 0.4906832298136646, 'eval_runtime': 0.1041, 'eval_samples_per_second': 28.821, 'eval_steps_per_second': 9.607, 'epoch': 0.12}
                                                                           
{'eval_loss': 6.612504959106445, 'eval_wer': 0.6944444444444444, 'eval_cer': 0.4906832298136646, 'eval_runtime': 0.0946, 'eval_samples_per_second': 31.709, 'eval_steps_per_second': 10.57, 'epoch': 0.13}
                                                                        
{'eval_loss': 6.606250286102295, 'eval_wer': 0.6944444444444444, 'eval_cer': 0.5031055900621118, 'eval_runtime': 0.1131, 'eval_samples_per_second': 26.526, 'eval_steps_per_second': 8.842, 'epoch': 0.13}
                                                                             
{'eval_loss': 6.6001152992248535, 'eval_wer': 0.7222222222222222, 'eval_cer': 0.546583850931677, 'eval_runtime': 0.1143, 'eval_samples_per_second': 26.242, 'eval_steps_per_second': 8.747, 'epoch': 0.13}
                                                                        
{'eval_loss': 6.594000339508057, 'eval_wer': 0.7222222222222222, 'eval_cer': 0.546583850931677, 'eval_runtime': 0.1058, 'eval_samples_per_second': 28.356, 'eval_steps_per_second': 9.452, 'epoch': 0.13}
                                                                        
{'eval_loss': 6.587998867034912, 'eval_wer': 0.7777777777777778, 'eval_cer': 0.6211180124223602, 'eval_runtime': 0.1148, 'eval_samples_per_second': 26.138, 'eval_steps_per_second': 8.713, 'epoch': 0.13}
                                                                        
{'eval_loss': 6.582004070281982, 'eval_wer': 0.8055555555555556, 'eval_cer': 0.6956521739130435, 'eval_runtime': 0.1031, 'eval_samples_per_second': 29.1, 'eval_steps_per_second': 9.7, 'epoch': 0.13}

1 Like