Hi there,
I’ve been getting wav2vec 2.0 up and running locally, following the example code for facebook/wav2vec2-base-960h:
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch
from jiwer import wer

librispeech_eval = load_dataset("librispeech_asr", "clean", split="test")

model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to("cuda")
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")

def map_to_pred(batch):
    input_values = processor([b["array"] for b in batch["audio"]], return_tensors="pt", padding="longest").input_values
    with torch.no_grad():
        logits = model(input_values.to("cuda")).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)
    batch["transcription"] = transcription
    return batch

result = librispeech_eval.map(map_to_pred, batched=True, batch_size=1, remove_columns=["audio"])

print("WER:", wer(result["text"], result["transcription"]))
However, when I run the code, I get the following warning repeatedly (once per batch, as far as I can tell):
It is strongly recommended to pass the ``sampling_rate`` argument to this function. Failing to do so can result in silent errors that might be hard to debug.
I’d like to do what it says by providing the sampling_rate to the Wav2Vec2Processor. However, when I build the Wav2Vec2FeatureExtractor myself, explicitly passing the sampling_rate (and other arguments) as shown below:
from transformers import Wav2Vec2CTCTokenizer, Wav2Vec2FeatureExtractor, Wav2Vec2Processor

tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('facebook/wav2vec2-base-960h')
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1, sampling_rate=16000, padding_value=0.0, return_attention_mask=False,
    do_normalize=True, padding_side="right"
)
processor = Wav2Vec2Processor(feature_extractor, tokenizer)
this does not make the warning go away.
Even if I try to use warnings.catch_warnings() to suppress the warning, it doesn’t seem to help either, although this seems like bad practice anyway.
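One possible explanation, which I have not confirmed: if the message is emitted through Python’s logging module rather than through warnings.warn, then warnings.catch_warnings() would never see it. Below is a minimal stdlib-only sketch of that difference. The names demo_library and noisy_call are hypothetical stand-ins, not transformers code.

```python
import io
import logging
import warnings

# Hypothetical stand-in for a library that reports problems via logging
logger = logging.getLogger("demo_library")
stream = io.StringIO()
logger.addHandler(logging.StreamHandler(stream))
logger.propagate = False  # keep the demo output confined to our stream


def noisy_call():
    # Emitted through logging, not through warnings.warn
    logger.warning("It is strongly recommended to pass the sampling_rate argument")


# catch_warnings only filters messages issued with warnings.warn,
# so the logged message still gets through
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    noisy_call()
after_catch_warnings = stream.getvalue()

# Raising the logger's level does silence it: nothing new is written
logger.setLevel(logging.ERROR)
noisy_call()
after_set_level = stream.getvalue()

print("logged despite catch_warnings:", "sampling_rate" in after_catch_warnings)
print("silenced by setLevel:", after_set_level == after_catch_warnings)
```

If that is what is happening here, transformers.logging.set_verbosity_error() would presumably hide the message, though hiding it still feels like the wrong fix compared to actually passing the sampling rate.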
I also found this Stack Exchange question and this Kaggle notebook showing the same problem.
I have been looking at the source code for Wav2Vec2FeatureExtractor as well, and while I can see where the warning is being raised, I can’t see where the code is going wrong.
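From my reading, the check seems to look at the sampling_rate argument of the call itself, not at the value stored on the feature extractor at construction. Below is a minimal sketch of passing it per call, assuming that reading is right. The sine-wave input is made up for illustration and stands in for real audio.

```python
import numpy as np
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1, sampling_rate=16000, padding_value=0.0,
    do_normalize=True, return_attention_mask=False,
)

# One second of a 440 Hz tone at 16 kHz (illustrative input only)
t = np.arange(16000, dtype=np.float32) / 16000.0
audio = np.sin(2 * np.pi * 440.0 * t).astype(np.float32)

# sampling_rate is passed to the call, not only to the constructor
inputs = feature_extractor(audio, sampling_rate=16000, return_tensors="np")
print(inputs.input_values.shape)
```

In map_to_pred this would presumably mean calling processor(..., sampling_rate=16000, return_tensors="pt", padding="longest"), on the assumption that the processor forwards the keyword to its feature extractor.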
I would be happy to provide a PR, but so far it seems I can’t fix the issue either…
Thanks