Hi, I have an audio data set in the following format: 16 kHz audio files in a folder named “audio”, and a pandas DataFrame of labels mapping each audio file to its label.
(Code to create this data set is at the end of this post)
Question:
What is the standard way to create a Dataset from this data to train an audio classification model?
More specifically, how can I use the facebook/hubert-large-ls960-ft feature extractor to create a Dataset to train a HuBERT model? I have the additional requirement of truncating/padding the input to 10 seconds, which I’ve done in the preprocess_function below.
What I tried:
import numpy as np
import os
import pandas as pd
import soundfile as sf
from datasets import Dataset, Audio
from transformers import Wav2Vec2FeatureExtractor

# create the dataset from the pandas DataFrame of labels
ds = Dataset.from_pandas(labels)
ds = ds.cast_column("audio", Audio(sampling_rate=16_000))

# feature extractor
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/hubert-large-ls960-ft")

def preprocess_function(examples):
    audio_arrays = [examples["audio"]["array"]]
    inputs = feature_extractor(
        audio_arrays,
        sampling_rate=16_000,
        max_length=int(16_000 * 10),  # truncate/pad to 10 s
        truncation=True,
        padding="max_length",
    )
    return inputs

# map the preprocessing function over the dataset
ds = ds.map(preprocess_function, remove_columns="audio")
This works fine when the data set is small, but it fails when there are many audio files (N ≈ 10,000) because the map operation exhausts memory. I’m probably doing something wrong, because this clearly doesn’t align with “The magic of memory mapping”. What am I doing wrong? Thanks!
Code to create the data set:
# number of examples
N = 10

# labels file
labels = pd.DataFrame({
    'audio': [os.path.join('audio_dir', f"{i}.wav") for i in range(N)],
    'label': np.random.choice(['A', 'B'], N),
})

# save dummy audio files
os.makedirs("audio_dir", exist_ok=True)
for file_path in labels['audio']:
    # random length between 5 s and 15 s at 16 kHz
    dummy_audio = np.random.randn(np.random.randint(80_000, 240_000))
    sf.write(file_path, dummy_audio, 16_000)
Hi! This is a good way to define a dataset for audio classification.
During map, only one batch at a time is loaded into memory and passed to your preprocess_function. To use less memory, you can try reducing writer_batch_size (the default is 1,000).
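For reference, a minimal sketch of what that would look like (the value 100 here is just an arbitrary example, not a recommendation):

# fewer rows are accumulated before the Arrow writer flushes them,
# so less data sits in RAM at any one time
ds = ds.map(
    preprocess_function,
    remove_columns="audio",
    writer_batch_size=100,  # default is 1,000
)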
Thanks @lhoestq! I think there’s something wrong here. I’ve tried with a data set size of N=10_000 and it always crashes on Colab (~13 GB RAM), even with batch_size=1.
(The code I provided is reproducible on the free Colab tier with N=10_000.)
Another observation: memory usage grows roughly linearly while ds.map() is running. Could it be that it isn’t garbage collecting?
Thanks @lhoestq, unfortunately it’s the same even when I try the smallest possible values with N=10_000. Could it be that I’m making a mistake somewhere else in my code (I mean the provided minimal example)?
I just created this reproducible example for Colab, but I get the same issue with a larger data set on another machine with 16 GB RAM. I think 16 GB should be enough, given that the generators aren’t supposed to process everything in memory.
Can you check ds.cache_files? Since you loaded the dataset from memory using .from_pandas, the dataset has no associated cache directory to save intermediate results.
To fix this, you can specify cache_file_name in .map(); this way it will write the results to your disk instead of using memory.
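In other words, something like this (a minimal sketch; the file name "processed.arrow" is just a placeholder):

# an in-memory dataset created with .from_pandas has no cache files
print(ds.cache_files)  # -> []

# write the mapped results to an Arrow file on disk instead of keeping them in RAM
ds = ds.map(
    preprocess_function,
    remove_columns="audio",
    cache_file_name="processed.arrow",
)
print(ds.cache_files)  # -> now lists the file that was just written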
When I checked ds.cache_files, it returned an empty list.
Then I tried ds = ds.map(preprocess_function, remove_columns='audio', cache_file_name='test') and it worked with no issues at all. ds.cache_files then became [{'filename': 'test'}].
Thanks a lot for your help.
If you don’t mind me asking, how did you get this?
Since you loaded the dataset from memory using .from_pandas, the dataset has no associated cache directory to save intermediate results.
I’ve read the docs for days but was never able to figure this out.
Basically, a Dataset is just a wrapper around an Arrow table. It can be an InMemoryTable, a MemoryMappedTable (i.e. backed by local files), or a combination of both as a ConcatenationTable. You can check the underlying Arrow table with ds.data
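For instance, a quick way to see which case you are in (a small sketch, reusing the ds from above):

# a dataset built with .from_pandas is backed by an InMemoryTable
print(type(ds.data))

# a dataset loaded from disk (or a map result written with cache_file_name)
# is backed by a MemoryMappedTable instead, so it doesn't take up RAM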