Dataset.map stuck with `torch.set_num_threads` set to 2 or larger

For a few days I’ve been trying to figure out how to speed up inference. I came across Dataset.map’s num_proc parameter, and I also found that PyTorch has a torch.set_num_threads(int) method. I tried different combinations of num_proc and torch.set_num_threads and ran into an issue: everything works fine as long as the thread count is 1 (for any num_proc) or num_proc is 1 (for any thread count). But if I set num_proc to 2, 3, … and the thread count to 2, then Dataset.map gets stuck. I waited for an hour on a really small dataset without any success.

The code to reproduce:

from datetime import datetime
from time import time_ns

import pandas as pd
import torch
from tqdm import tqdm

from datasets import Dataset
from datasets.utils.logging import disable_progress_bar
from transformers import AutoModel, AutoTokenizer

def get_metrics_num_proc_num_threads_row(paragraph_count, batch_size, num_proc, num_threads):
    def _get_embeddings(texts):
        encoded_input = tokenizer(
            texts, padding=True, truncation=True, return_tensors='pt'
        )

        with torch.no_grad():
            encoded_input = {
                k: v.to(device) for k, v in encoded_input.items()
            }
            model_output = model(**encoded_input)

        return model_output.pooler_output.tolist()
    
    torch.set_num_threads(num_threads)

    # model and tokenizer could be replaced with
    # sentence-transformers/paraphrase-multilingual-mpnet-base-v2
    model_dir = '../model'
    tokenizer_dir = '../tokenizer'

    model = AutoModel.from_pretrained(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
    device = 'cpu'  # 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)

    one_hundred_token_text = 'The absence of or inadequate structural support within the eye (posterior capsular deformities, zonular laxity or dehiscence) would hinder the success of IOL implantation, as this may potentially lead to lens instability or dislocation/decentration (Source: Cataract Surgery Guidelines. The Royal College of Ophthalmologists, Regent Park, London. Sept. 2010).'

    dataset = Dataset.from_list([])
    for text in [one_hundred_token_text] * paragraph_count:
        dataset = dataset.add_item(
            {'text': text, 'request_idx': 0}
        )

    start = time_ns()
    predictions = dataset.map(
        lambda x: {
            'embeddings': _get_embeddings(x['text']),
            'request_idx': x['request_idx'],
        },
        batched=True,
        batch_size=batch_size,
        num_proc=num_proc
    )
    end = time_ns()

    row = {
        'batch_size': batch_size,
        'num_proc': num_proc,
        'elapsed': end - start,
        'paragraph_count': paragraph_count
    }
    
    return row

def get_metrics_num_proc_num_threads(num_procs: int, num_threads: int, paragraph_counts: list[int]) -> pd.DataFrame:
    disable_progress_bar()

    rows = []
    date = datetime.now().strftime('%Y%m%d')
    batch_size = 64

    for paragraph_count in tqdm(paragraph_counts, desc='paragraph_counts'):
        row = get_metrics_num_proc_num_threads_row(
            paragraph_count=paragraph_count,
            batch_size=batch_size,
            num_proc=num_procs,
            num_threads=num_threads)
        rows.append(row)

    report_procs = num_procs
    report_batches = batch_size
    report_paragraphs = '_'.join(map(str, paragraph_counts))
    df = pd.DataFrame(rows)
    df['elapsed'] = df['elapsed'].apply(lambda x: x / (10**9))
    return df

metrics_hf = {}
for num_proc, num_thread in [
    (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1),  # work fine
    (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7),          # work fine
    (2, 2), (3, 2), (4, 2), (5, 2), (6, 2), (7, 2),          # stuck
]:
    metrics_hf[(num_proc, num_thread)] = get_metrics_num_proc_num_threads(
        num_procs=num_proc,
        num_threads=num_thread,
        paragraph_counts=[10, 25, 50, 100, 200, 400])

And the stack trace below (captured after interrupting the stuck run):

KeyboardInterrupt                         Traceback (most recent call last)
Cell In[83], line 2
      1 num_proc, num_thread = 2, 2
----> 2 metrics_hf[(num_proc, num_thread)] = get_metrics_num_proc_num_threads(
      3         num_procs=num_proc,
      4         num_threads=num_thread,
      5         paragraph_counts=[10, 25, 50, 100, 200, 400])

Cell In[62], line 9, in get_metrics_num_proc_num_threads(num_procs, num_threads, paragraph_counts)
      6 batch_size=64
      8 for paragraph_count in tqdm(paragraph_counts, desc='paragraph_counts'):
----> 9     row = get_metrics_num_proc_num_threads_row(
     10         paragraph_count=paragraph_count,
     11         batch_size=batch_size,
     12         num_proc=num_proc,
     13         num_threads=num_threads)
     14     rows.append(row)
     16 report_procs = num_proc

File ~/git/ml/document-model-pytorch/envs/lib/python3.11/site-packages/joblib/memory.py:594, in MemorizedFunc.__call__(self, *args, **kwargs)
    593 def __call__(self, *args, **kwargs):
--> 594     return self._cached_call(args, kwargs)[0]

File ~/git/ml/document-model-pytorch/envs/lib/python3.11/site-packages/joblib/memory.py:537, in MemorizedFunc._cached_call(self, args, kwargs, shelving)
    534         must_call = True
    536 if must_call:
--> 537     out, metadata = self.call(*args, **kwargs)
    538     if self.mmap_mode is not None:
    539         # Memmap the output at the first call to be consistent with
    540         # later calls
    541         if self._verbose:

File ~/git/ml/document-model-pytorch/envs/lib/python3.11/site-packages/joblib/memory.py:779, in MemorizedFunc.call(self, *args, **kwargs)
    777 if self._verbose > 0:
    778     print(format_call(self.func, args, kwargs))
--> 779 output = self.func(*args, **kwargs)
    780 self.store_backend.dump_item(
    781     [func_id, args_id], output, verbose=self._verbose)
    783 duration = time.time() - start_time

Cell In[61], line 39, in get_metrics_num_proc_num_threads_row(paragraph_count, batch_size, num_proc, num_threads)
     34     dataset = dataset.add_item(
     35         {'text': input, 'request_idx': 0}
     36     )
     38 start = time_ns()
---> 39 predictions = dataset.map(
     40     lambda x: {
     41         'embeddings': _get_embeddings(x['text']),
     42         'request_idx': x['request_idx'],
     43     },
     44     batched=True,
     45     batch_size=batch_size,
     46     num_proc=num_proc
     47 )
     48 end = time_ns()
     49 # result = handler.postprocess(inferenced)

File ~/git/ml/document-model-pytorch/envs/lib/python3.11/site-packages/datasets/arrow_dataset.py:563, in transmit_tasks.<locals>.wrapper(*args, **kwargs)
    561     self: "Dataset" = kwargs.pop("self")
    562 # apply actual function
--> 563 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    564 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    565 for dataset in datasets:
    566     # Remove task templates if a column mapping of the template is no longer valid

File ~/git/ml/document-model-pytorch/envs/lib/python3.11/site-packages/datasets/arrow_dataset.py:528, in transmit_format.<locals>.wrapper(*args, **kwargs)
    521 self_format = {
    522     "type": self._format_type,
    523     "format_kwargs": self._format_kwargs,
    524     "columns": self._format_columns,
    525     "output_all_columns": self._output_all_columns,
    526 }
    527 # apply actual function
--> 528 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
    529 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
    530 # re-apply format to the output

File ~/git/ml/document-model-pytorch/envs/lib/python3.11/site-packages/datasets/arrow_dataset.py:3046, in Dataset.map(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)
   3038 logger.info(f"Spawning {num_proc} processes")
   3039 with logging.tqdm(
   3040     disable=not logging.is_progress_bar_enabled(),
   3041     unit=" examples",
   (...)
   3044     desc=(desc or "Map") + f" (num_proc={num_proc})",
   3045 ) as pbar:
-> 3046     for rank, done, content in iflatmap_unordered(
   3047         pool, Dataset._map_single, kwargs_iterable=kwargs_per_job
   3048     ):
   3049         if done:
   3050             shards_done += 1

File ~/git/ml/document-model-pytorch/envs/lib/python3.11/site-packages/datasets/utils/py_utils.py:1368, in iflatmap_unordered(pool, func, kwargs_iterable)
   1366 while True:
   1367     try:
-> 1368         yield queue.get(timeout=0.05)
   1369     except Empty:
   1370         if all(async_result.ready() for async_result in async_results) and queue.empty():

File <string>:2, in get(self, *args, **kwds)

File ~/git/ml/document-model-pytorch/envs/lib/python3.11/site-packages/multiprocess/managers.py:818, in BaseProxy._callmethod(self, methodname, args, kwds)
    815     conn = self._tls.connection
    817 conn.send((self._id, methodname, args, kwds))
--> 818 kind, result = conn.recv()
    820 if kind == '#RETURN':
    821     return result

File ~/git/ml/document-model-pytorch/envs/lib/python3.11/site-packages/multiprocess/connection.py:258, in _ConnectionBase.recv(self)
    256 self._check_closed()
    257 self._check_readable()
--> 258 buf = self._recv_bytes()
    259 return _ForkingPickler.loads(buf.getbuffer())

File ~/git/ml/document-model-pytorch/envs/lib/python3.11/site-packages/multiprocess/connection.py:422, in Connection._recv_bytes(self, maxsize)
    421 def _recv_bytes(self, maxsize=None):
--> 422     buf = self._recv(4)
    423     size, = struct.unpack("!i", buf.getvalue())
    424     if size == -1:

File ~/git/ml/document-model-pytorch/envs/lib/python3.11/site-packages/multiprocess/connection.py:387, in Connection._recv(self, size, read)
    385 remaining = size
    386 while remaining > 0:
--> 387     chunk = read(handle, remaining)
    388     n = len(chunk)
    389     if n == 0:

Could you please explain what I’m doing wrong? Is it not allowed to change the thread count?

PS: I ran the experiments on a t3a.2xlarge EC2 instance (8 vCPUs, 32 GB RAM).

Based on Dataloader hangs. Potential deadlock with `set_num_threads` in worker processes? · Issue #75147 · pytorch/pytorch · GitHub, calling torch.set_num_threads seems to lead to such deadlocks when the multiprocessing start method is fork. So maybe switching to spawn is the solution:

import multiprocess
multiprocess.set_start_method("spawn", force=True)
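
As a minimal sketch of how this workaround could be wired in (the toy transform and data below are placeholders, not the original benchmark, and it is an assumption that this fully resolves the hang rather than a verified fix):

import multiprocess
import torch
from datasets import Dataset


def transform(batch):
    # Setting the thread count inside the worker (instead of only in the
    # parent before the pool is created) is another commonly suggested
    # workaround for the fork-related deadlock.
    torch.set_num_threads(2)
    return {'length': [len(text) for text in batch['text']]}


if __name__ == '__main__':
    # Has to run before datasets creates its worker pool for the first time.
    multiprocess.set_start_method('spawn', force=True)

    dataset = Dataset.from_dict({'text': ['some paragraph'] * 100})
    dataset = dataset.map(transform, batched=True, batch_size=16, num_proc=2)
    print(dataset[0])

The __main__ guard matters with spawn: freshly started workers re-import the script, and without the guard they would try to run the map call again themselves.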