Datasets filter/map hangs when multithreading

Hi all, I am trying to use filter/map on a dataset in a script, and I am finding that the script hangs upon completion of the filter/map operations (the tqdm progress bar always goes up to 100%)

I am calling filter like my_dataset = my_dataset.filter(lambda example: example['image_id'] not in some_set, num_proc=32).

I tracked down where the code is hanging using faulthandler.dump_traceback(), whose output is:

Thread 0x00002b0014ec1700 (most recent call first):
  File ".../lib/python3.8/threading.py", line 306 in wait
  File ".../lib/python3.8/threading.py", line 558 in wait
  File "~/virtualenv/lib/python3.8/site-packages/tqdm/_monitor.py", line 60 in run
  File ".../lib/python3.8/threading.py", line 932 in _bootstrap_inner
  File ".../lib/python3.8/threading.py", line 890 in _bootstrap

Thread 0x00002b0af28f1740 (most recent call first):
  File "~/virtualenv/lib/python3.8/site-packages/multiprocess/popen_fork.py", line 27 in poll
  File "~/virtualenv/lib/python3.8/site-packages/multiprocess/popen_fork.py", line 47 in wait
  File "~/virtualenv/lib/python3.8/site-packages/multiprocess/process.py", line 149 in join
  File "~/virtualenv/lib/python3.8/site-packages/multiprocess/pool.py", line 729 in _terminate_pool
  File "~/virtualenv/lib/python3.8/site-packages/multiprocess/util.py", line 224 in __call__
  File "~/virtualenv/lib/python3.8/site-packages/multiprocess/pool.py", line 654 in terminate
  File "~/virtualenv/lib/python3.8/site-packages/multiprocess/pool.py", line 736 in __exit__
  File "~/virtualenv/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 3105 in map
  File "~/virtualenv/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 528 in wrapper
  File "~/virtualenv/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 563 in wrapper
  File "~/virtualenv/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 3531 in filter
  File "~/virtualenv/lib/python3.8/site-packages/datasets/fingerprint.py", line 511 in wrapper
  File "~/virtualenv/lib/python3.8/site-packages/datasets/arrow_dataset.py", line 528 in wrapper
  File "my_script.py", line 145 in encode -> line where filter is called

(I believe the thread 0x00002b0014ec1700 corresponds to tqdm. When I disable tqdm, I only get the stack trace from a single thread analogous to the thread 0x00002b0af28f1740 shown above.)

So based on the stack trace from the second thread 0x00002b0af28f1740, the script stops at datasets/arrow_dataset.py at fd893098627230cc734f6009ad04cf885c979ac4 Ā· huggingface/datasets Ā· GitHub, i.e. when the pool object is destroyed by the context manager for going out of scope.

I would appreciate any help on why this occurs / how to fix. I am using python 3.8, datasets 2.11, pyarrow 10.0.1.

Thanks in advance!

I would also like to add that this is not reproducible 100% of the timeā€¦ (it worked in the first few runs which made me believe there was no issue)

Hi ! Is it still happening if you lower num_proc ?

And if you interrupt the process, do you have another stacktrace that could help ?

Hi @lhoestq !

The issue does not occur when num_proc is 1 or None.

When I Ctrl-C, there is strangely no stack trace printed, but I believe faulthandler.dump_traceback prints the stack trace of all threads.

Here is a relatively minimal example that hangs on my environment:

from datasets import load_dataset


def encode():
    qrels = load_dataset("TREC-AToMiC/AToMiC-Qrels-v0.2", split="train")

    valid_texts = set(qrels.unique('text_id'))

    texts = load_dataset("TREC-AToMiC/AToMiC-Texts-v0.2", split='train')
    texts = texts.remove_columns(['media', 'category', 'source_id', 'page_url'])
    print("START FILTER")
    texts = texts.filter(lambda example: bool(example['text_id'] in valid_texts), num_proc=16)
    print("END FILTER")


def main():
    encode()


if __name__ == "__main__":
    main()

This never gets to the print("END FILTER") callā€¦

I tried to run the script multiple times with different num_proc value but couldnā€™t reproduce (mac m2, py 3.9, datasets 2.10, multiproces 0.70.14).

Could you share more information about your machine ?

Hmm okay I actually omitted one line of code in the snippet above since I thought it was irrelevant, but Iā€™m now actually thinking that might be the cause. The line is from pyserini.pyclass import autoclass at the top of the script (which is this library).

The reason I think this import could be problematic is that this causes the jnius library to also imported. Based on this doc, using jnius along with overriding the thread.run functionality could lead to crashes. I believe the multiprocess library used by datasets.map does indeed have its own custom class analogous to threads and defines its custom run function (here if Iā€™m not mistaken?). So this could result in the threads not finishing properly and that would somehow cause the process to hang?

Also, I havenā€™t verified this myself but according to Troubles using jnius in a multiprocess environment Ā· Issue #640 Ā· kivy/pyjnius Ā· GitHub, it seems like this problem may not happen if the spawn method is used, therefore you may need to run something like multiprocessing.set_start_method('fork') since the default is spawn on MacOS.

Here is my environment info:
CentOS Linux 7 (Core)
Python 3.8.10
datasets 2.11.0
multiprocess 0.70.14

Additionally, I can also verify that I could run the above code snippet (without the import) without deadlocks over multiple runs.

After I terminate the hanging program, there is still one process from the program that is still alive (based on htop). It has the following stack trace

Thread 102043 (idle): "MainThread"
 __enter__ (multiprocess/synchronize.py:101)
get (multiprocess/queues.py:358)
worker (multiprocess/pool.py:114)
run (multiprocess/process.py:108)
_bootstrap (multiprocess/process.py:315)
_launch (multiprocess/popen_fork.py:75)
__init__ (multiprocess/popen_fork.py:19)
_Popen (multiprocess/context.py:277)
start (multiprocess/process.py:121)
_repopulate_pool_static (multiprocess/pool.py:326)
_repopulate_pool (multiprocess/pool.py:303)
__init__ (multiprocess/pool.py:212)
Pool (multiprocess/context.py:119)
map (datasets/arrow_dataset.py:3087)
wrapper (datasets/arrow_dataset.py:528)
wrapper (datasets/arrow_dataset.py:563)
filter (datasets/arrow_dataset.py:3531)
wrapper (datasets/fingerprint.py:511)
wrapper (datasets/arrow_dataset.py:528)
encode (test.py:13)

which looks like it is stuck while trying to create the pool? Not sure how to interpret this as the stack trace when the script was hanging seemed to indicate it stopped when trying to destroy the pool.

datasets actually uses multiprocess for multiprocessing. Itā€™s a library that works like multiprocessing but uses dill under the hood to serialize python objects.

You can set its start method using

import multiprocess.context as ctx
ctx._force_start_method('spawn')