I’ve been getting wav2vec 2.0 up and running locally, following the example code for facebook/wav2vec2-base-960h:
```python
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch
from jiwer import wer

librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")

model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to("cuda")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

def map_to_pred(batch):
    # Turn the raw waveforms into padded model inputs
    input_values = processor([b["array"] for b in batch["audio"]], return_tensors="pt", padding="longest").input_values
    with torch.no_grad():
        logits = model(input_values.to("cuda")).logits

    # Greedy CTC decoding: most likely token per frame, then decode to text
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)
    batch["transcription"] = transcription
    return batch

result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])

print("WER:", wer(result["text"], result["transcription"]))
```
However, when I run the code, the following warning is printed over and over:
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
I’d like to do what it says and provide the sampling_rate to the Wav2Vec2Processor, but when I instead build the Wav2Vec2FeatureExtractor myself, explicitly passing sampling_rate (and the other arguments) as shown below,
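```python
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor

tokenizer = Wav2Vec2CTCTokenizer.from_pretrained("facebook/wav2vec2-base-960h")
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    return_attention_mask=False,
    do_normalize=True,
    padding_side="right",  # the keyword is padding_side; "padding_size" is not a recognized option
)
processor = Wav2Vec2Processor(feature_extractor, tokenizer)
```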
this does not make the warning go away.
Wrapping the call in warnings.catch_warnings() to suppress the message doesn’t help either (and suppressing it seems like bad practice anyway).
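That warnings.catch_warnings() has no effect is consistent with the message being emitted through Python's logging module rather than the warnings module, assuming the feature extractor reports it through a module-level logging logger, as transformers modules typically do. The stdlib-only sketch below illustrates the difference; the logger name "demo_lib.feature_extraction" and the function extract_features are hypothetical stand-ins for the library code, not transformers internals. For controlling the library's log output, transformers provides transformers.logging.set_verbosity_error().

```python
import logging
import warnings

# Hypothetical stand-in for a library's module-level logger (not transformers' own name).
lib_logger = logging.getLogger("demo_lib.feature_extraction")


def extract_features():
    """Hypothetical library call that reports a problem via logging, not warnings."""
    lib_logger.warning(
        "It is strongly recommended to pass the ``sampling_rate`` argument to this function."
    )
    return [0.0]


class ListHandler(logging.Handler):
    """Collects log messages in memory so they can be inspected."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record):
        self.messages.append(record.getMessage())


def run_demo():
    handler = ListHandler()
    lib_logger.addHandler(handler)
    try:
        # catch_warnings only intercepts warnings.warn(...), so it records nothing here
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            extract_features()
    finally:
        lib_logger.removeHandler(handler)
    return caught, handler.messages


caught, logged = run_demo()
print(len(caught), len(logged))  # 0 1
```

The warnings filter sees nothing, while a handler attached to the logger receives the message; this is why only logging-level controls (or passing the argument) change what is printed.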
I have also looked at the source code of Wav2Vec2FeatureExtractor, and while I can see where the warning is raised, I can’t see what is going wrong.
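A minimal sketch for checking where the warning comes from, assuming (as in the transformers source I am aware of) that the feature extractor's `__call__` warns whenever its own sampling_rate argument is None, regardless of the sampling_rate given to the constructor. It builds the extractor locally (no download) and records log messages from the "transformers" library logger; the helper extract and the class _Collect are illustrative names, and numpy plus transformers are assumed to be installed.

```python
import logging

import numpy as np
from transformers import Wav2Vec2FeatureExtractor


class _Collect(logging.Handler):
    """Illustrative handler that stores log messages in a list."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record):
        self.messages.append(record.getMessage())


def extract(audio, sampling_rate=None):
    """Illustrative helper: run the feature extractor and return (input_values, log messages)."""
    handler = _Collect()
    lib_logger = logging.getLogger("transformers")  # transformers' library root logger
    lib_logger.addHandler(handler)
    try:
        fe = Wav2Vec2FeatureExtractor(
            feature_size=1,
            sampling_rate=16000,  # constructor value: describes what the model expects
            padding_value=0.0,
            do_normalize=True,
            return_attention_mask=False,
        )
        # Call-time value: describes the audio actually being passed in
        out = fe([audio], sampling_rate=sampling_rate, padding="longest", return_tensors="np")
    finally:
        lib_logger.removeHandler(handler)
    return out.input_values, handler.messages


audio = np.random.randn(16000).astype(np.float32)  # one second of noise at 16 kHz
_, without_rate = extract(audio)
_, with_rate = extract(audio, sampling_rate=16000)
print(len(without_rate) > 0, len(with_rate) == 0)
```

If this holds for your version, the analogous change in map_to_pred would be passing sampling_rate=16000 in the processor call itself, assuming Wav2Vec2Processor forwards extra keyword arguments to the feature extractor.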
I would be happy to submit a PR, but so far I haven’t been able to fix the issue myself…