Wav2vec2.0 memory issue for basic inference

I am trying to run inference with the pretrained wav2vec2-base-960h model on a single WAV file that is 2:17 minutes long.

I tried it both on my laptop (32GB RAM) and on a Google Colab machine with a T4 GPU.

I get an out-of-memory (OOM) error on both.

The error I get on my laptop looks like this:

That would be just 2GB of memory. My laptop has 32GB, and plenty of it was free when I ran this.

On Google Colab, it just says that the session crashed due to memory issues.
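A rough estimate shows why a 2:17 clip is heavy for this model. The sketch below assumes 16 kHz audio, the wav2vec2-base defaults from transformers' `Wav2Vec2Config` (convolutional `conv_kernel`/`conv_stride`, 12 layers, 12 attention heads), float32 weights, and an attention implementation that materializes the full frames-by-frames attention matrix (as the eager implementation does). Under those assumptions one layer's attention weights come to about 2.25 GB, in the same range as the 2GB above. Without `torch.no_grad()`, autograd keeps tensors like these for every layer so it can run a backward pass, which would add up to roughly 27 GB for attention alone.

```python
# Back-of-the-envelope size of the self-attention weights for a long clip.
# Assumptions: 16 kHz input, wav2vec2-base defaults from transformers'
# Wav2Vec2Config, float32, full (heads, frames, frames) attention matrix.

CONV_KERNEL = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDE = (5, 2, 2, 2, 2, 2, 2)
NUM_HEADS = 12
NUM_LAYERS = 12
BYTES_PER_FLOAT32 = 4


def num_frames(num_samples):
    """Frames produced by the convolutional feature encoder (no padding)."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNEL, CONV_STRIDE):
        length = (length - kernel) // stride + 1
    return length


def attention_bytes_per_layer(frames):
    """Size of one float32 attention tensor of shape (heads, frames, frames)."""
    return NUM_HEADS * frames * frames * BYTES_PER_FLOAT32


samples = (2 * 60 + 17) * 16_000  # 2:17 at an assumed 16 kHz
frames = num_frames(samples)
per_layer = attention_bytes_per_layer(frames)

print(f"{samples} samples -> {frames} frames")
# prints "2192000 samples -> 6849 frames"
print(f"attention weights: {per_layer / 1e9:.2f} GB per layer, "
      f"{NUM_LAYERS * per_layer / 1e9:.1f} GB for all {NUM_LAYERS} layers")
# prints "attention weights: 2.25 GB per layer, 27.0 GB for all 12 layers"
```

The cost grows with the square of the clip length, so halving the audio length cuts this figure to about a quarter.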

The code I’m using is the most basic inference example from the wav2vec GitHub page:

# !pip install transformers
# !pip install datasets
import soundfile as sf
import torch
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# load pretrained model
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

librispeech_samples_ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")

# load audio
audio_input, sample_rate = sf.read(librispeech_samples_ds[0]["file"])

# pad input values and return pt tensor
input_values = processor(audio_input, sampling_rate=sample_rate, return_tensors="pt").input_values

# retrieve logits & take argmax
logits = model(input_values).logits # this is the line where OOM appears
predicted_ids = torch.argmax(logits, dim=-1)

# transcribe
transcription = processor.decode(predicted_ids[0])
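The forward pass above runs with autograd enabled (the model's parameters require gradients by default), so intermediate activations are kept for a backward pass that never happens. The sketch below is one common mitigation, not necessarily the fix from the fairseq issue mentioned in this post: it disables autograd with `torch.no_grad()` and feeds the waveform in fixed-size chunks. The function name `ctc_logits_in_chunks`, the 30-second default `chunk_size` (assuming 16 kHz audio), and the 400-sample `min_samples` (the receptive field of the wav2vec2-base feature encoder under its default config) are illustrative assumptions.

```python
import torch


def ctc_logits_in_chunks(model, input_values, chunk_size=16_000 * 30, min_samples=400):
    """Hypothetical helper: run a CTC model over a long waveform chunk by chunk.

    input_values: tensor of shape (batch, samples), e.g. the processor output.
    Returns logits of shape (batch, frames, vocab), concatenated over time.
    """
    total = input_values.shape[-1]
    if total == 0:
        raise ValueError("input_values is empty")

    # Chunk boundaries: 0, chunk_size, 2 * chunk_size, ..., total.
    bounds = list(range(0, total, chunk_size)) + [total]
    # A tail shorter than the conv encoder's receptive field would yield zero
    # frames and fail, so fold it into the previous chunk instead.
    if len(bounds) > 2 and bounds[-1] - bounds[-2] < min_samples:
        del bounds[-2]

    pieces = []
    # no_grad: no activations are stored for backward, so memory is freed
    # layer by layer instead of accumulating across the whole network.
    with torch.no_grad():
        for start, end in zip(bounds[:-1], bounds[1:]):
            pieces.append(model(input_values[..., start:end]).logits)

    # Logits are (batch, frames, vocab); join the chunks along the frame axis.
    return torch.cat(pieces, dim=1)
```

With the objects from the snippet above, `logits = ctc_logits_in_chunks(model, input_values)` would replace the `model(input_values).logits` line, and the `torch.argmax` and `processor.decode` steps stay the same. Fixed-size chunks can cut a word in half at a boundary, so splitting at silences or using overlapping windows would be more accurate; for a clip of about two minutes, wrapping the original call in `torch.no_grad()` may be enough on its own.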

The solution is described in this thread: "wave2vec OOM while doing inference", Issue #3359 in facebookresearch/fairseq on GitHub, and it does fix the issue.

I will close this one then.