Hi @patrickvonplaten and all other wav2vec LM users,
Thanks for the new LM addition. When running the following code (with or without a GPU), the machine runs out of CPU memory. It looks like something should be freed but isn't. This does not happen without the LM, so I guess it has something to do with pyctcdecode. Any ideas, or am I missing a step?
This code triggers the OOM with the model running on a GPU. Set the range higher if needed, but memory usage should keep increasing either way.
from datasets import load_dataset
from transformers import Wav2Vec2ProcessorWithLM, Wav2Vec2ForCTC
import torch
dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm").to("cuda")
print("* * * * * * loaded models")
for i in range(200):
    # Wrap the index in case the split has fewer than 200 samples.
    audio_sample = dataset[i % len(dataset)]
    print(" * * * * Sample: ", i)
    inputs = processor(audio_sample["audio"]["array"], sampling_rate=audio_sample["audio"]["sampling_rate"], return_tensors="pt").to("cuda")
    with torch.no_grad():
        logits = model(**inputs).logits
    # Decoding with the LM runs on the CPU, so move the logits off the GPU first.
    transcription = processor.batch_decode(logits.cpu().numpy()).text
    print(transcription[0].lower())
And a CPU-only version, which just takes longer to reach the OOM:
from datasets import load_dataset
from transformers import Wav2Vec2ProcessorWithLM, Wav2Vec2ForCTC
import torch
dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
print("* * * * * * loaded models")
for i in range(200):
    # Wrap the index in case the split has fewer than 200 samples.
    audio_sample = dataset[i % len(dataset)]
    print(" * * * * Sample: ", i)
    inputs = processor(audio_sample["audio"]["array"], sampling_rate=audio_sample["audio"]["sampling_rate"], return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    transcription = processor.batch_decode(logits.numpy()).text
    print(transcription[0].lower())
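To put numbers on the growth instead of waiting for the OOM, something like the following could be wrapped around the loop body. This is only a rough sketch using the standard library: `peak_rss_mb` and `track_peak_memory` are hypothetical helper names, not part of transformers or pyctcdecode, and the `resource` module is Unix-only. It reports the peak resident memory of the process, so it shows whether memory keeps climbing but not what is holding it.

```python
import resource
import sys


def peak_rss_mb():
    """Return the peak resident set size of this process in MiB (Unix only)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in bytes on macOS and in kilobytes on Linux.
    divisor = 1024 * 1024 if sys.platform == "darwin" else 1024
    return peak / divisor


def track_peak_memory(step, n_steps, report_every=10):
    """Call step(i) n_steps times and record the peak RSS after each call."""
    history = []
    for i in range(n_steps):
        step(i)
        history.append(peak_rss_mb())
        if i % report_every == 0:
            print(f"step {i}: peak RSS {history[-1]:.1f} MiB")
    return history
```

With the loop body above moved into a function, say a hypothetical `transcribe(i)`, calling `track_peak_memory(transcribe, 200)` once with the LM processor and once with the plain processor should make the difference visible: a steadily rising peak in the first case and a plateau in the second would point at the decoding step.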