Wav2vec with new LM causing CPU OOM

Hi @patrickvonplaten and all other wav2vec LM users,
thanks for the new LM addition. When running the following code (also when running without a GPU), the CPU runs out of memory. It looks like something should be freed but isn't. This does not happen without the LM, so I suspect it is related to pyctcdecode. Any ideas, or am I missing a step?

This code will trigger the OOM when running on a GPU; the leak is in CPU RAM. Set the range higher if needed, but memory usage should increase either way (a small RSS-monitoring sketch follows the snippet).

from datasets import load_dataset
from transformers import Wav2Vec2ProcessorWithLM, Wav2Vec2ForCTC
import torch

dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm").to("cuda")
print("* * * * * * loaded models")

for i in range(200):
    audio_sample = dataset[i]
    print(" * * * * Sample: ", i)
    # feature-extract the raw audio and move the tensors to the GPU
    inputs = processor(audio_sample["audio"]["array"], sampling_rate=audio_sample["audio"]["sampling_rate"], return_tensors="pt").to("cuda")
    with torch.no_grad():
        logits = model(**inputs).logits
    # beam-search decode with the LM (pyctcdecode) -- CPU memory keeps growing here
    transcription = processor.batch_decode(logits.cpu().numpy()).text
    print(transcription[0].lower())
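
For reference, here is a small sketch of how the growth can be confirmed: watch the resident set size of the Python process around the batch_decode call. This assumes psutil is installed; the rss_mb helper is just my name for it, not part of transformers.

import os
import psutil  # assumed installed: pip install psutil

def rss_mb() -> float:
    # resident set size of the current process, in MiB
    return psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2

# inside the loop above, e.g.:
#     before = rss_mb()
#     transcription = processor.batch_decode(logits.cpu().numpy()).text
#     print(f"sample {i}: RSS grew by {rss_mb() - before:.1f} MiB")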

And a CPU-only version; it will just take longer to hit the OOM :slight_smile:

from datasets import load_dataset
from transformers import Wav2Vec2ProcessorWithLM, Wav2Vec2ForCTC
import torch

dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
processor = Wav2Vec2ProcessorWithLM.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
model = Wav2Vec2ForCTC.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")
print("* * * * * * loaded models")

for i in range(200):
    audio_sample = dataset[i]
    print(" * * * * Sample: ", i)
    inputs = processor(audio_sample["audio"]["array"], sampling_rate=audio_sample["audio"]["sampling_rate"], return_tensors="pt")
    with torch.no_grad():
      logits = model(**inputs).logits
    transcription = processor.batch_decode(logits.numpy()).text
    print(transcription[0].lower())

Also asked this on GitHub.

And solved on GitHub: a missing pool.close() in the processor code.
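
Until a release with the fix is out, one workaround is to manage the multiprocessing pool yourself and close it when you are done. The sketch below assumes your transformers version exposes a pool argument on batch_decode (if it does not, decoding samples one at a time with processor.decode avoids the internal pool entirely); it is not the official fix, just a way to keep the leak in check.

from multiprocessing import get_context

# reuse one fork-based pool for all decode calls and close it afterwards
# (the pool= argument is an assumption about your transformers version)
with get_context("fork").Pool(processes=4) as pool:
    for i in range(200):
        audio_sample = dataset[i]
        inputs = processor(audio_sample["audio"]["array"], sampling_rate=audio_sample["audio"]["sampling_rate"], return_tensors="pt")
        with torch.no_grad():
            logits = model(**inputs).logits
        transcription = processor.batch_decode(logits.numpy(), pool=pool).text
        print(transcription[0].lower())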
