I am currently trying to train a Speech2Text model from scratch, but what I am seeing during training is odd.
For some reason the word error rate (WER) is already quite good right after the first step, 50% or less, which simply cannot be right.
The WER then increases as the model progressively returns more and more garbage.
Clearly, there is something wrong with the way I am training the model.
Can somebody help me find out what I'm doing wrong here, or point me to an example that explains how to do this right?
This is how I am initializing the Seq2SeqTrainer and the Speech2TextForConditionalGeneration:
from functools import partial

from transformers import (
    EarlyStoppingCallback,
    IntervalStrategy,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    Speech2TextConfig,
    Speech2TextForConditionalGeneration,
    trainer_utils,
)

training_args = Seq2SeqTrainingArguments(
    output_dir=out_dir,
    evaluation_strategy=IntervalStrategy("steps"),
    save_steps=save_and_eval_steps,
    eval_steps=save_and_eval_steps,
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir=log_dir,
    group_by_length=True,
    load_best_model_at_end=True,
    save_total_limit=2,
)

# Create the model from a fresh (untrained) configuration
config = Speech2TextConfig(
    return_dict=True,
    sampling_rate=sampling_rate,
    vocab_size=tokenizer.vocab_size,
    pad_token_id=processor.tokenizer.pad_token_id,
    bos_token_id=processor.tokenizer.bos_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    decoder_start_token_id=processor.tokenizer.bos_token_id,
)
model = Speech2TextForConditionalGeneration(config)
model.train()

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=Speech2TextCollator(processor=processor),
    compute_metrics=partial(compute_metrics, processor),
    callbacks=[EarlyStoppingCallback(early_stopping_patience=early_stopping_patience)],
)

# Resume from the latest checkpoint in out_dir if one exists (None otherwise)
last_checkpoint = trainer_utils.get_last_checkpoint(out_dir)
trainer.train(resume_from_checkpoint=last_checkpoint)
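As a sanity check of the wiring, I would expect a minimal forward pass like the following to work (a sketch with made-up shapes; input_features should be (batch, frames, mel_bins), where config.input_feat_per_channel defaults to 80):

import torch

# Hypothetical smoke test, not part of the training script: push one random
# batch through the untrained model and confirm that a finite loss comes out.
dummy_batch = dict(
    input_features=torch.randn(2, 100, config.input_feat_per_channel),  # (batch, frames, mel bins)
    attention_mask=torch.ones(2, 100, dtype=torch.long),                # one mask entry per frame
    labels=torch.randint(0, config.vocab_size, (2, 12)),                # target token ids
)
outputs = model(**dummy_batch)
print(outputs.loss, outputs.logits.shape)  # logits: (batch, target length, vocab size)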
The other important thing is probably what the input to the model is. So, the data collator does the following:
from dataclasses import dataclass
from typing import Dict, List, Union

import torch
from torch.nn.utils.rnn import pad_sequence

from transformers import Speech2TextProcessor


@dataclass
class Speech2TextCollator:
    processor: Speech2TextProcessor

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        inputs = [torch.Tensor(f["inputs"]) for f in features]
        targets = [torch.Tensor(f["targets"]) for f in features]
        # Create batches: pad everything to the longest example in the batch
        inputs_batch = pad_sequence(inputs, batch_first=True)
        targets_batch = pad_sequence(targets, batch_first=True).long()
        attention_mask = pad_sequence([f["attention_mask"] for f in features], batch_first=True).long()
        return dict(
            input_features=inputs_batch,
            attention_mask=attention_mask,
            labels=targets_batch,
        )
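To make the batch layout concrete, here is a hypothetical example of what the dataset hands to the collator and what comes back (the per-example layout with log-mel features of shape (frames, 80) is my assumption):

features = [
    dict(inputs=torch.randn(100, 80), targets=[5, 17, 23, 4], attention_mask=torch.ones(100)),
    dict(inputs=torch.randn(80, 80), targets=[5, 9, 4], attention_mask=torch.ones(80)),
]
batch = Speech2TextCollator(processor=processor)(features)
print({k: v.shape for k, v in batch.items()})
# input_features: (2, 100, 80), attention_mask: (2, 100), labels: (2, 4)
# pad_sequence fills the shorter example with zeros, in the labels as well.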
Probably less relevant, but for completeness, this is how I am computing the WER and CER:
import numpy as np


def compute_metrics(processor: Speech2TextProcessor, pred):
    # pred.predictions is a tuple of model outputs; the first element holds the logits
    pred_ids = np.argmax(pred.predictions[0], axis=-1)
    pred_str = processor.batch_decode(pred_ids)
    # We do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)
    # error_rate (defined elsewhere in my code) computes the rate over words or characters
    wer = error_rate(targets=label_str, predictions=pred_str, tokens="words")
    cer = error_rate(targets=label_str, predictions=pred_str, tokens="characters")
    return {"wer": wer, "cer": cer}
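For context, pred is the EvalPrediction that the Trainer builds from the model outputs; because the model returns several tensors, pred.predictions is a tuple whose first element holds the decoder logits, hence the [0]. A made-up call for illustration (it relies on the error_rate helper referenced above):

import numpy as np
from transformers import EvalPrediction

# Hypothetical shapes: 2 examples, 4 decoder positions, the tokenizer's vocab
logits = np.random.randn(2, 4, tokenizer.vocab_size)
label_ids = np.array([[5, 17, 23, 4], [5, 9, 4, 0]])
print(compute_metrics(processor, EvalPrediction(predictions=(logits,), label_ids=label_ids)))
# -> {"wer": ..., "cer": ...}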
The output during training looks something like this:
{'eval_loss': 6.661579608917236, 'eval_wer': 0.5277777777777778, 'eval_cer': 0.37888198757763975, 'eval_runtime': 0.1115, 'eval_samples_per_second': 26.916, 'eval_steps_per_second': 8.972, 'epoch': 0.11}
{'eval_loss': 6.655572414398193, 'eval_wer': 0.5833333333333334, 'eval_cer': 0.40993788819875776, 'eval_runtime': 0.0951, 'eval_samples_per_second': 31.544, 'eval_steps_per_second': 10.515, 'epoch': 0.11}
{'eval_loss': 6.649582386016846, 'eval_wer': 0.6111111111111112, 'eval_cer': 0.453416149068323, 'eval_runtime': 0.0953, 'eval_samples_per_second': 31.472, 'eval_steps_per_second': 10.491, 'epoch': 0.12}
{'eval_loss': 6.643609523773193, 'eval_wer': 0.6111111111111112, 'eval_cer': 0.453416149068323, 'eval_runtime': 0.1001, 'eval_samples_per_second': 29.982, 'eval_steps_per_second': 9.994, 'epoch': 0.12}
{'eval_loss': 6.637502193450928, 'eval_wer': 0.6388888888888888, 'eval_cer': 0.4658385093167702, 'eval_runtime': 0.1108, 'eval_samples_per_second': 27.085, 'eval_steps_per_second': 9.028, 'epoch': 0.12}
{'eval_loss': 6.631495952606201, 'eval_wer': 0.6388888888888888, 'eval_cer': 0.4658385093167702, 'eval_runtime': 0.1021, 'eval_samples_per_second': 29.372, 'eval_steps_per_second': 9.791, 'epoch': 0.12}
{'eval_loss': 6.6251702308654785, 'eval_wer': 0.6944444444444444, 'eval_cer': 0.4906832298136646, 'eval_runtime': 0.1079, 'eval_samples_per_second': 27.795, 'eval_steps_per_second': 9.265, 'epoch': 0.12}
{'eval_loss': 6.618703365325928, 'eval_wer': 0.6944444444444444, 'eval_cer': 0.4906832298136646, 'eval_runtime': 0.1041, 'eval_samples_per_second': 28.821, 'eval_steps_per_second': 9.607, 'epoch': 0.12}
{'eval_loss': 6.612504959106445, 'eval_wer': 0.6944444444444444, 'eval_cer': 0.4906832298136646, 'eval_runtime': 0.0946, 'eval_samples_per_second': 31.709, 'eval_steps_per_second': 10.57, 'epoch': 0.13}
{'eval_loss': 6.606250286102295, 'eval_wer': 0.6944444444444444, 'eval_cer': 0.5031055900621118, 'eval_runtime': 0.1131, 'eval_samples_per_second': 26.526, 'eval_steps_per_second': 8.842, 'epoch': 0.13}
{'eval_loss': 6.6001152992248535, 'eval_wer': 0.7222222222222222, 'eval_cer': 0.546583850931677, 'eval_runtime': 0.1143, 'eval_samples_per_second': 26.242, 'eval_steps_per_second': 8.747, 'epoch': 0.13}
{'eval_loss': 6.594000339508057, 'eval_wer': 0.7222222222222222, 'eval_cer': 0.546583850931677, 'eval_runtime': 0.1058, 'eval_samples_per_second': 28.356, 'eval_steps_per_second': 9.452, 'epoch': 0.13}
{'eval_loss': 6.587998867034912, 'eval_wer': 0.7777777777777778, 'eval_cer': 0.6211180124223602, 'eval_runtime': 0.1148, 'eval_samples_per_second': 26.138, 'eval_steps_per_second': 8.713, 'epoch': 0.13}
{'eval_loss': 6.582004070281982, 'eval_wer': 0.8055555555555556, 'eval_cer': 0.6956521739130435, 'eval_runtime': 0.1031, 'eval_samples_per_second': 29.1, 'eval_steps_per_second': 9.7, 'epoch': 0.13}