Using num_proc>1 in Dataset.map hangs

I’m trying to process audio data (LibriSpeech) faster using multiprocessing. Everything works when run serially, but when I pass num_proc>1 to .map, the progress bar just sits at 0.

Reproducer:

from datasets import load_dataset
from transformers import AutoFeatureExtractor

feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-tiny")
ds = load_dataset("librispeech_asr", "clean", split="test")

ds = ds.map(
    lambda batch: {
        "input_features": feature_extractor(
            batch["audio"]["array"], sampling_rate=16_000, return_tensors="pt"
        ).input_features
    },
    num_proc=2,
)

Ctrl+C backtrace:

File ~/workspace/sdks/venv/poplar_sdk-ubuntu_20_04-3.3.0+1401-54633455e9/3.3.0+1401_poptorch/lib/python3.8/site-packages/datasets/arrow_dataset.py:580, in transmit_tasks.<locals>.wrapper(*args, **kwargs)
    578     self: "Dataset" = kwargs.pop("self")
    579 # apply actual function
--> 580 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    581 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    582 for dataset in datasets:
    583     # Remove task templates if a column mapping of the template is no longer valid

File ~/workspace/sdks/venv/poplar_sdk-ubuntu_20_04-3.3.0+1401-54633455e9/3.3.0+1401_poptorch/lib/python3.8/site-packages/datasets/arrow_dataset.py:545, in transmit_format.<locals>.wrapper(*args, **kwargs)
    538 self_format = {
    539     "type": self._format_type,
    540     "format_kwargs": self._format_kwargs,
    541     "columns": self._format_columns,
    542     "output_all_columns": self._output_all_columns,
    543 }
    544 # apply actual function
--> 545 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    546 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    547 # re-apply format to the output

File ~/workspace/sdks/venv/poplar_sdk-ubuntu_20_04-3.3.0+1401-54633455e9/3.3.0+1401_poptorch/lib/python3.8/site-packages/datasets/arrow_dataset.py:3180, in Dataset.map(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)
   3172 logger.info(f"Spawning {num_proc} processes")
   3173 with logging.tqdm(
   3174     disable=not logging.is_progress_bar_enabled(),
   3175     unit=" examples",
   (...)
   3178     desc=(desc or "Map") + f" (num_proc={num_proc})",
   3179 ) as pbar:
-> 3180     for rank, done, content in iflatmap_unordered(
   3181         pool, Dataset._map_single, kwargs_iterable=kwargs_per_job
   3182     ):
   3183         if done:
   3184             shards_done += 1

File ~/workspace/sdks/venv/poplar_sdk-ubuntu_20_04-3.3.0+1401-54633455e9/3.3.0+1401_poptorch/lib/python3.8/site-packages/datasets/utils/py_utils.py:1354, in iflatmap_unordered(pool, func, kwargs_iterable)
   1351                 break
   1352 finally:
   1353     # we get the result in case there's an error to raise
-> 1354     [async_result.get(timeout=0.05) for async_result in async_results]

File ~/workspace/sdks/venv/poplar_sdk-ubuntu_20_04-3.3.0+1401-54633455e9/3.3.0+1401_poptorch/lib/python3.8/site-packages/datasets/utils/py_utils.py:1354, in <listcomp>(.0)
   1351                 break
   1352 finally:
   1353     # we get the result in case there's an error to raise
-> 1354     [async_result.get(timeout=0.05) for async_result in async_results]

File ~/workspace/sdks/venv/poplar_sdk-ubuntu_20_04-3.3.0+1401-54633455e9/3.3.0+1401_poptorch/lib/python3.8/site-packages/multiprocess/pool.py:767, in ApplyResult.get(self, timeout)
    765 self.wait(timeout)
    766 if not self.ready():
--> 767     raise TimeoutError
    768 if self._success:
    769     return self._value

Could be related to this: Load_dataset hangs with local files - #4 by lhoestq
I’m doing this inside an IPython session…

I tried on main and it still hangs. However, I found that the real cause of the hang seems to be torch threads: setting torch.set_num_threads(1) made it work. :smiling_face:
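For anyone else landing here, a minimal sketch of that workaround (the OMP_NUM_THREADS line is an extra belt-and-braces addition on my part, not part of the original fix):

```python
import os

# Cap OpenMP-backed thread pools (numpy/MKL etc.) before any heavy
# imports; an extra precaution on top of the torch fix.
os.environ["OMP_NUM_THREADS"] = "1"

try:
    import torch

    # The fix reported above: keep torch's intra-op pool single-threaded
    # so the workers forked by .map don't inherit a locked thread pool.
    torch.set_num_threads(1)
except ImportError:
    pass  # torch not installed; the env var still caps OpenMP pools
```

The key point is to run this before calling ds.map(..., num_proc=2) — setting it after the worker pool has already forked is too late.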


Hi,

I am facing the same problem. I tried it on main and also set torch.set_num_threads(1), but it still does not work:

 Processing Data..... (num_proc=10):   0%|          | 0/520 [00:00<?, ? examples/s]

Reproducer:

df = pd.read_csv(self.args["csv_input"])
ds = Dataset.from_pandas(df)

ds = ds.map(
    self.process_data,
    num_proc=self.args["num_workers"],
    with_indices=True,
    batched=True,
    batch_size=int(len(df) / self.args["num_workers"]),
    load_from_cache_file=True,
    desc="Processing Data.....",
)
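As a side note on the snippet above (unrelated to the hang): int(len(df) / num_workers) truncates, so when the row count isn’t a multiple of num_workers the leftover rows land in an extra small batch. Ceiling division avoids that; the numbers below are purely illustrative:

```python
import math

# Illustrative numbers only: 523 rows split across 10 workers.
num_rows, num_workers = 523, 10

truncated = int(num_rows / num_workers)     # truncates: 52, leaving a 3-row leftover batch
ceiled = math.ceil(num_rows / num_workers)  # covers all rows in 10 batches: 53

print(truncated, ceiled)  # 52 53
```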

Packages:
Python 3.10.13
datasets 2.15.1.dev0

Any ideas on how to solve this issue? thanks a million
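One thing that may help to debug a hang like this is Python’s standard-library faulthandler: register it at startup and you can dump every thread’s traceback from a stuck process without killing it (a POSIX-only sketch, since it uses SIGUSR1):

```python
import faulthandler
import signal
import sys

# Put this near the top of the script. When the process appears hung,
# run `kill -USR1 <pid>` from another shell: the tracebacks of all
# threads are printed to stderr without terminating the process.
faulthandler.register(signal.SIGUSR1)

# For a quick one-off check, you can also dump the stacks immediately:
faulthandler.dump_traceback(file=sys.stderr)
```

If the dump shows workers parked inside a native thread-pool call, that points at the same fork/threads interaction described above.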

@wjassim Hey, I am facing the exact same issue. Is there a possible solution to this?

Hi, unfortunately I don’t have a solution. I’m not sure what the main problem is.

Hey, I’ve noticed that calling torch.set_num_threads(1) at the beginning of the code solves the issue.

I see, I will try it out. Thanks a million…