Using WebDataset to stream data

I’m trying to train a model on this dataset: MLCommons/unsupervised_peoples_speech · Datasets at Hugging Face.

I’m using WebDataset to iterate over the tar files via a brace expansion. It is basically a wrapper on top of torch’s IterableDataset. The problem is that if I set up more than one worker in the loader, I get the following error:

Traceback (most recent call last):
  File "/usr/lib/python3.10/multiprocessing/process.py", line 314, in _bootstrap
    self.run()
  File "/usr/lib/python3.10/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/home/rafael/unsupervised_peoples_speech/data_wrangling/multi_VAD.py", line 57, in producer
    for i in islice(dataset,0,128):
  File "/home/rafael/unsupervised_peoples_speech/speech/lib/python3.10/site-packages/webdataset/pipeline.py", line 70, in iterator
    yield from self.iterator1()
  File "/home/rafael/unsupervised_peoples_speech/speech/lib/python3.10/site-packages/webdataset/filters.py", line 397, in _to_tuple
    for sample in data:
  File "/home/rafael/unsupervised_peoples_speech/speech/lib/python3.10/site-packages/webdataset/tariterators.py", line 219, in group_by_keys
    for filesample in data:
  File "/home/rafael/unsupervised_peoples_speech/speech/lib/python3.10/site-packages/webdataset/tariterators.py", line 190, in tar_file_expander
    if handler(exn):
  File "/home/rafael/unsupervised_peoples_speech/speech/lib/python3.10/site-packages/webdataset/filters.py", line 86, in reraise_exception
    raise exn
  File "/home/rafael/unsupervised_peoples_speech/speech/lib/python3.10/site-packages/webdataset/tariterators.py", line 177, in tar_file_expander
    for sample in tar_file_iterator(
  File "/home/rafael/unsupervised_peoples_speech/speech/lib/python3.10/site-packages/webdataset/tariterators.py", line 149, in tar_file_iterator
    if handler(exn):
  File "/home/rafael/unsupervised_peoples_speech/speech/lib/python3.10/site-packages/webdataset/filters.py", line 86, in reraise_exception
    raise exn
  File "/home/rafael/unsupervised_peoples_speech/speech/lib/python3.10/site-packages/webdataset/tariterators.py", line 142, in tar_file_iterator
    data = stream.extractfile(tarinfo).read()
  File "/usr/lib/python3.10/tarfile.py", line 689, in read
    b = self.fileobj.read(length)
  File "/usr/lib/python3.10/tarfile.py", line 526, in read
    buf = self._read(size)
  File "/usr/lib/python3.10/tarfile.py", line 534, in _read
    return self.__read(size)
  File "/usr/lib/python3.10/tarfile.py", line 564, in __read
    buf = self.fileobj.read(self.bufsize)
  File "/home/rafael/unsupervised_peoples_speech/speech/lib/python3.10/site-packages/webdataset/gopen.py", line 88, in read
    self.check_status()
  File "/home/rafael/unsupervised_peoples_speech/speech/lib/python3.10/site-packages/webdataset/gopen.py", line 68, in check_status
    self.wait_for_child()
  File "/home/rafael/unsupervised_peoples_speech/speech/lib/python3.10/site-packages/webdataset/gopen.py", line 83, in wait_for_child
    raise IOError(f"{self.args}: exit {self.status} (read) {info}")
OSError: ("(('curl -s -L https://huggingface.co/datasets/MLCommons/unsupervised_peoples_speech/resolve/main/audio/000004.tar -H Authorization:Bearer hf_ASDASDSAD',), {'shell': True, 'bufsize': 8192}): exit 6 (read) {} @ <Pipe (('curl -s -L https://huggingface.co/datasets/MLCommons/unsupervised_peoples_speech/resolve/main/audio/000004.tar -H Authorization:Bearer hf_ASDASDASD',), {'shell': True, 'bufsize': 8192})>", <webdataset.gopen.Pipe object at 0x7f897a132080>, 'pipe:curl -s -L https://huggingface.co/datasets/MLCommons/unsupervised_peoples_speech/resolve/main/audio/000004.tar -H Authorization:Bearer hf_ASDASDASDP')

I’ve done some testing, and if I add a time.sleep this error doesn’t occur, which makes me think there’s a rate limiter in place. Is there a way to overcome this? Any suggestions?

Edit: I’m including the code for reference:

import torch
import torchaudio
import braceexpand 
import webdataset as wds
from itertools import islice

from utils import resample_squeeze, extract_tar_number
import multiprocessing
from io import BytesIO
import time
import json
import os
import uuid
from dotenv import load_dotenv

load_dotenv()

NUM_PROCESS=128
token  = os.environ['HF_TOKEN']
url = "https://huggingface.co/datasets/MLCommons/unsupervised_peoples_speech/resolve/main/audio/{000001..000005}.tar"
token = f'Authorization:Bearer {token}'
urls = list(braceexpand.braceexpand(url))
urls = [f"pipe:curl -s -L {url} -H {token}" for url in urls]

model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                model='silero_vad',
                                force_reload=True,
                                onnx=True)

(get_speech_timestamps,
save_audio,
read_audio,
VADIterator,
collect_chunks) = utils

def vad_process(sample: str, model):
    with torch.no_grad():
        return get_speech_timestamps(
            sample,
            model
            )
    
dataset = (wds.WebDataset(urls, nodesplitter=wds.split_by_node).
          to_tuple('mp3', '__key__','__url__', handler = wds.handlers.ignore_and_continue))

def producer(queue, dataset):
    for i in islice(dataset,0,1024):
        queue.put(i)
    for i in range(NUM_PROCESS):
        queue.put(None)
    print('Producer: Done', flush=True)

def consumer(queue):
    model, _ = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                model='silero_vad',
                                force_reload=False,
                                onnx=True)
    cur_id = uuid.uuid4()
    while True:
        item = queue.get()
        if item is None:
            break
        try:
            audio = torchaudio.load(BytesIO(item[0]))
            waveform, duration = resample_squeeze(audio)
        except Exception:  # skip samples that fail to decode or resample
            continue
        result = {item[1]: {'timestamps': vad_process(waveform, model), 'duration': duration, 'tar_number': extract_tar_number(item[-1])}}
        with open(f'results/vad_results_{cur_id}.jsonl', 'a+') as f:
            f.write(json.dumps(result) + '\n')
    print('Consumer: Done', flush=True)

start = time.perf_counter()
queue = multiprocessing.Queue(maxsize=64)
results = []
consumer_processes = [multiprocessing.Process(target=consumer, args=(queue,)) for _ in range(NUM_PROCESS)]
for process in consumer_processes:
    process.start()
producer_process = multiprocessing.Process(target=producer, args=(queue, dataset))
producer_process.start()
producer_process.join()
for process in consumer_processes:
    process.join()

end = time.perf_counter()
print(f"Time: {end - start}")

Hi! Maybe the issue is that all your workers download from the same file. Are you sure wds.split_by_node works as expected here? (It relies on torch.distributed, which has its own multiprocessing mechanism IIRC.)
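
To illustrate what "each worker gets its own files" means, here is a minimal sketch of a round-robin shard split. The helper `shards_for_worker` is hypothetical and is not WebDataset's implementation; it only shows the idea that every worker should see a disjoint subset of the shard list.

```python
from itertools import islice

def shards_for_worker(urls, worker_id, num_workers):
    """Return every num_workers-th shard, starting at worker_id (hypothetical helper)."""
    return list(islice(urls, worker_id, None, num_workers))

# Five shard names, mirroring the {000001..000005} brace expansion in the question
urls = [f"{i:06d}.tar" for i in range(1, 6)]
num_workers = 2
assignment = {w: shards_for_worker(urls, w, num_workers) for w in range(num_workers)}
print(assignment)
```

If two workers end up with the same shard, both open a connection to the same file and the samples are processed twice.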

Btw (not really related): when streaming it is important to add a retry mechanism in case there is an issue with the connection.

You can use for example:

pipe:curl --connect-timeout 30 --retry 30 --retry-delay 3 -f -s -L {url} -H {token}
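
A minimal sketch of how that command could be turned into the URL list from the original script. The helper name `make_pipe_url` and the placeholder token `hf_xxx` are hypothetical. Quoting with `shlex.quote` is an addition that is not in the command above: the pipe is run through a shell, so a header value containing a space would otherwise be split into two arguments.

```python
import shlex

# Flags taken from the curl command suggested above
CURL_FLAGS = "--connect-timeout 30 --retry 30 --retry-delay 3 -f -s -L"

def make_pipe_url(url: str, token: str) -> str:
    """Build a 'pipe:' URL that streams one shard through curl (hypothetical helper)."""
    header = f"Authorization: Bearer {token}"
    # shlex.quote keeps the header (which contains spaces) as a single shell argument
    return f"pipe:curl {CURL_FLAGS} {shlex.quote(url)} -H {shlex.quote(header)}"

base = "https://huggingface.co/datasets/MLCommons/unsupervised_peoples_speech/resolve/main/audio"
urls = [make_pipe_url(f"{base}/{i:06d}.tar", "hf_xxx") for i in range(1, 6)]
print(urls[0])
```

The resulting list can be passed to wds.WebDataset in place of the `urls` list built in the original script.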

Alternatively, you could use the datasets library with streaming + .map() + a DataLoader with multiple workers. I’d be happy to provide a code example if you’re interested.

Hey @lhoestq! Thanks for your answer. We thought WebDataset would be the best option given our file structure, but if you have a code example at hand of the datasets library I could give that a test.

I also removed wds.split_by_node as I thought that could be the cause, but I’m still getting that cryptic error from curl that doesn’t really say much.

Here is some code using datasets (I haven’t tested it, but it should be a good start).

I used a torch DataLoader for multiprocessing to make the code simpler:

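import json
import uuid

import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm

# vad_process and extract_tar_number are the helpers from the original script above

# Lazily initialized once per DataLoader worker process
model = None
cur_id = None

def consume(example):
    global model, cur_id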
    if model is None:
        model, _ = torch.hub.load(
            repo_or_dir='snakers4/silero-vad',
            model='silero_vad',
            force_reload=False,
            onnx=True
        )
        cur_id = uuid.uuid4()
    waveform = example["mp3"]["array"]
    duration = len(waveform) / example["mp3"]["sampling_rate"]
    result = {example["__key__"]: {'timestamps': vad_process(waveform, model), 'duration': duration, 'tar_number': extract_tar_number(example["__url__"])}}
    with open(f'results/vad_results_{cur_id}.jsonl', 'a+') as f:
        return {"written": f.write(json.dumps(result) + '\n')}

if __name__ == "__main__":
    ds = load_dataset("MLCommons/unsupervised_peoples_speech", split="train", streaming=True)
    ds = ds.map(consume, remove_columns=list(ds.features))
    num_workers = 11 * 11
    assert ds.n_shards % num_workers == 0  # this way all the workers have the same number of shards to process
    dataloader = DataLoader(ds, num_workers=num_workers)
    for _ in tqdm(dataloader):
        pass
    print("done !")
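
The assertion in the script above only passes when the shard count is a multiple of num_workers. A small sketch of one way to pick such a value automatically; `pick_num_workers` is a hypothetical helper, and the shard count of 242 is an example value, not the real shard count of the dataset.

```python
def pick_num_workers(n_shards: int, max_workers: int) -> int:
    """Largest worker count <= max_workers that divides n_shards evenly (hypothetical helper)."""
    for candidate in range(min(n_shards, max_workers), 0, -1):
        if n_shards % candidate == 0:
            return candidate
    return 1  # fallback when n_shards is 0 or negative

# Example shard count; in the script above this would be ds.n_shards
num_workers = pick_num_workers(242, max_workers=128)
print(num_workers)
```

With an even split, every worker processes the same number of shards, so no worker sits idle while the others finish.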

Oh, thanks @lhoestq! I thought you had some general example, and wasn’t expecting you to actually write a specific one. While trying to use load_dataset I’m getting an error in the data loader: some of the directories inside the tar files contain .srt files, and I believe there is some sort of aggregation function that requires all samples to contain the same keys. I thought running with verification_mode=no_checks would just emit a warning and continue, but it raises an error every single time. Do you happen to have any advice for this? (Our tar files contain mp3s as well as some srts, but we currently process only the mp3s.)

Indeed, datasets requires all the examples to have the same keys, since a Dataset / IterableDataset has a fixed set of features.

Specifying the full set of features in advance can fix this though, e.g.

from datasets import Audio, Features, Value

features = Features({
    "__key__": Value("string"),
    "__url__": Value("string"),
    "mp3": Audio(),
    "srt": Value("binary"),
})
ds = load_dataset(..., features=features)

This should fix the code, and this way all the examples will have the “srt” feature, with its value set to None if no .srt file is associated with the audio file.
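
As a plain-Python illustration of that behavior (not how datasets implements it internally), the sketch below aligns samples to a fixed key set and fills missing keys with None. `FEATURE_KEYS` and `align_to_features` are hypothetical names, and the sample values are made up.

```python
# Same keys as the Features definition above
FEATURE_KEYS = ("__key__", "__url__", "mp3", "srt")

def align_to_features(sample: dict) -> dict:
    """Return the sample with exactly FEATURE_KEYS; missing keys become None."""
    return {key: sample.get(key) for key in FEATURE_KEYS}

with_srt = {"__key__": "a", "__url__": "000001.tar", "mp3": b"audio", "srt": b"subtitles"}
without_srt = {"__key__": "b", "__url__": "000001.tar", "mp3": b"audio"}

aligned = [align_to_features(s) for s in (with_srt, without_srt)]
print(aligned[1]["srt"])
```

Both samples now expose the same keys, so code that only reads "mp3" works regardless of whether an .srt file was present.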

Is there any additional documentation on this? I tried what you mentioned and it seems to work (I ended up setting the mp3 to binary as audio had some dependency issues).

By the way, using the datasets load_dataset method lets me load data much faster, as I don’t need the time.sleep(), but I’m still interested in knowing whether there is a rate limiter in place. Even if I include a number of retries and timeouts in the curl command, I still eventually get “Curl 6: Could not resolve hostname”. I managed to delay this a little by switching HF tokens between curl calls inside the brace expansion, which makes me think there is a limiter in place per HF token. Is there?