Set sampling_rate in wav2vec 2.0 processor

Hi there,

I’ve been getting wav2vec 2.0 up and running locally, following the example code for facebook/wav2vec2-base-960h:

```python
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch
from jiwer import wer


librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")

model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to("cuda")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

def map_to_pred(batch):
    input_values = processor([b["array"] for b in batch["audio"]], return_tensors="pt", padding="longest").input_values
    with torch.no_grad():
        logits = model(input_values.to("cuda")).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)
    batch["transcription"] = transcription
    return batch

result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])

print("WER:", wer(result["text"], result["transcription"]))
```

However, when I run the code, I get the following warning repeatedly:

```
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
```

I’d like to do what it says and provide the sampling_rate to the Wav2Vec2Processor, so I tried building the Wav2Vec2FeatureExtractor myself, passing sampling_rate (and the other arguments) explicitly:

```python
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor

tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('facebook/wav2vec2-base-960h')

# padding_side (not padding_size) is the keyword the feature extractor reads
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1, sampling_rate=16000, padding_value=0.0, return_attention_mask=False,
    do_normalize=True, padding_side="right"
)

processor = Wav2Vec2Processor(feature_extractor, tokenizer)
```

However, this does not make the warning go away.
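A minimal sketch of what the warning appears to be asking for, assuming a transformers version where `Wav2Vec2FeatureExtractor.__call__` accepts a `sampling_rate` keyword. The `sampling_rate` given to the constructor only records the rate the extractor expects. The check that emits the warning looks at the `sampling_rate` passed with each call, so setting it in the constructor alone does not silence it. The snippet builds the extractor offline (no download) and uses one second of synthetic noise as a stand-in for a LibriSpeech clip. `return_tensors="np"` is chosen only to avoid a torch dependency.

```python
import numpy as np
from transformers import Wav2Vec2FeatureExtractor

# Same settings as the extractor built above.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1, sampling_rate=16000, padding_value=0.0,
    return_attention_mask=False, do_normalize=True, padding_side="right",
)

# One second of synthetic audio at 16 kHz (a stand-in, not real speech).
audio = np.random.default_rng(0).standard_normal(16000).astype(np.float32)

# Pass the rate of the audio at call time; this is the argument the warning refers to.
inputs = feature_extractor(
    [audio], sampling_rate=16000, return_tensors="np", padding="longest"
)
print(inputs.input_values.shape)  # (1, 16000)
```

In `map_to_pred` the same keyword should pass through the processor, e.g. `processor([b["array"] for b in batch["audio"]], sampling_rate=16000, return_tensors="pt", padding="longest")`. Each decoded `datasets` audio entry also carries its own rate under `b["sampling_rate"]` if you would rather not hard-code 16000.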

Using warnings.catch_warnings() to suppress the warning doesn’t help either (and suppressing it seems like bad practice anyway).
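Why `warnings.catch_warnings()` has no effect can be sketched with the standard library alone. This assumes the message is emitted through a `logging` logger (`logger.warning(...)`) rather than `warnings.warn(...)`. The `extract` function below is a hypothetical stand-in, not the real feature extractor. Filters from the `warnings` module only see `warnings.warn` calls, so a logging record passes straight through them. Passing the argument stops the message. Raising the logger's level also hides it. For the library's own loggers, `transformers.logging.set_verbosity_error()` does this, though it hides every other warning too.

```python
import io
import logging
import warnings

MESSAGE = "It is strongly recommended to pass the `sampling_rate` argument to this function."

# Hypothetical stand-in logger: reports problems via logging, not warnings.warn().
logger = logging.getLogger("demo.feature_extractor")
logger.setLevel(logging.WARNING)
logger.propagate = False

# Collect everything this logger emits so it can be counted.
stream = io.StringIO()
logger.addHandler(logging.StreamHandler(stream))


def extract(audio, sampling_rate=None):
    """Hypothetical extractor that logs a warning when no rate is given."""
    if sampling_rate is None:
        logger.warning(MESSAGE)
    return audio


# 1) A warnings filter does not touch logging records: the message is still emitted.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    extract([0.0, 0.1])

# 2) Passing the argument the message asks for: nothing is emitted.
extract([0.0, 0.1], sampling_rate=16000)

# 3) Raising the logger's level hides the message (a blunter option).
logger.setLevel(logging.ERROR)
extract([0.0, 0.1])

print(stream.getvalue().count(MESSAGE))  # 1
```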

I also found this Stack Exchange question and this Kaggle notebook that run into the same problem.

I have also been looking at the source code for Wav2Vec2FeatureExtractor, and while I can see where the warning is raised, I can’t see where the code is going wrong.

I would be happy to open a PR, but so far I haven’t been able to fix the issue myself…

Thanks
