Code review: compute_metrics for WER with Wav2Vec2ProcessorWithLM

Hi!

I ran into an error using the compute_metrics function when fine-tuning a wav2vec 2.0 model the way this blog post recommends.

When fine-tuning (and evaluating) with a language model (following this blog post), processor.batch_decode() failed to recover the reference utterance from pred (for example, "hello" from pred.label_ids[0] = [1, 2, 3, 3, 4], where vocab = {"h": 1, "e": 2, ...}).

I’ve been able to write a compute_metrics() that works for me, but I was wondering if there’s a better way to generate the label_str array for wer_metric.compute() than what I’ve done below:

def compute_metrics(pred):

    pred_logits = pred.predictions
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    if type(processor).__name__ == "Wav2Vec2ProcessorWithLM":

        pred_str    = processor.batch_decode(pred_logits).text

        # Can't use processor.batch_decode(pred.label_ids) like below here

        # [[1, 2, 3, 3, 4, 31, 32], ...] => [["h", "e", "l", "l", "o", "|", "</s>"], ...]
        label_chars = [ processor.tokenizer.convert_ids_to_tokens(l) for l in pred.label_ids ]
        # [["h", "e", "l", "l", "o", "|", "</s>"], ...] => ["hello|", ...]
        label_str   = [ "".join([ tok for tok in l if tok not in processor.tokenizer.unique_no_split_tokens ]) for l in label_chars ]
        # ["hello|", ...] => ["hello", ...]
        label_str   = [ l.replace(processor.tokenizer.word_delimiter_token, " ").strip() for l in label_str ]

    else:

        pred_ids = np.argmax(pred_logits, axis=-1)

        pred_str = processor.batch_decode(pred_ids)
        # we do not want to group tokens when computing the metrics
        label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}
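For context, here is a toy, pure-Python sketch of what the manual reconstruction above is doing (the vocab, special-token set, and helper name are hypothetical, just to make the steps concrete):

```python
# Hypothetical toy vocab and special tokens, standing in for the real
# tokenizer's convert_ids_to_tokens / unique_no_split_tokens /
# word_delimiter_token.
vocab = {1: "h", 2: "e", 3: "l", 4: "o", 31: "|", 32: "</s>"}
special_tokens = {"<s>", "</s>", "<pad>", "<unk>"}
word_delimiter_token = "|"

def decode_labels(label_ids):
    # ids -> characters (convert_ids_to_tokens)
    chars = [vocab[i] for i in label_ids]
    # drop special tokens, join into one string
    text = "".join(c for c in chars if c not in special_tokens)
    # replace the word delimiter "|" with a space
    return text.replace(word_delimiter_token, " ").strip()

print(decode_labels([1, 2, 3, 3, 4, 31, 32]))  # -> "hello"
```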

Thanks!


@patrickvonplaten — wondering if you have any suggestions (hope you don’t mind the tag…), thanks!

Hey @fauxneticien,

Good question! I think you should do the following:

def compute_metrics(pred):

    pred_logits = pred.predictions
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    if type(processor).__name__ == "Wav2Vec2ProcessorWithLM":
        pred_str    = processor.batch_decode(pred_logits).text
    else:
        pred_ids = np.argmax(pred_logits, axis=-1)
        pred_str = processor.batch_decode(pred_ids)

    # this will always work (even if the processor has a decoder LM)
    label_str = processor.tokenizer.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}
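To see why the labels are decoded with group_tokens=False while the predictions are not, here is a minimal pure-Python sketch (toy vocab and helper names are made up): CTC predictions contain repeated frames and blanks that must be collapsed, whereas labels are literal character sequences where a repeat (the "ll" in "hello") is real and must be kept.

```python
# Toy sketch of CTC collapsing vs. literal (group_tokens=False) decoding.
BLANK = 0
vocab = {1: "h", 2: "e", 3: "l", 4: "o"}

def ctc_collapse(ids):
    # prediction path: merge consecutive duplicates, then drop blanks
    out, prev = [], None
    for i in ids:
        if i != prev:
            out.append(i)
        prev = i
    return "".join(vocab[i] for i in out if i != BLANK)

def literal_decode(ids):
    # label path (group_tokens=False): keep every id as-is
    return "".join(vocab[i] for i in ids)

pred_frames = [1, 1, 2, 0, 3, 3, 0, 3, 4]   # per-frame argmax ids
print(ctc_collapse(pred_frames))            # -> "hello"
print(literal_decode([1, 2, 3, 3, 4]))      # -> "hello"
print(ctc_collapse([1, 2, 3, 3, 4]))        # -> "helo" (double "l" wrongly merged!)
```

This is why running the labels through CTC-style decoding would silently corrupt any reference word with doubled letters.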

You can always access the tokenizer from the processor 🙂

Thanks @patrickvonplaten — great tip!