For a few days I've been trying to figure out how I can speed up inference. I got stuck with the num_proc parameter of Dataset.map. I also found that PyTorch has a torch.set_num_threads(int) method. I've tried different combinations of num_proc and torch.set_num_threads and ran into an issue: everything works fine with num_threads = 1 and num_proc equal to 1 or 2, but if I set num_proc to 2, 3, … and the thread count to 2, then Dataset.map gets stuck. I've waited for an hour on a really small dataset without any success.
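Boiled down, the combination that hangs for me looks roughly like this (a stripped-down sketch, not my actual script: the model name here is the public sentence-transformers checkpoint mentioned in the comment of the full code below, while my real run loads local copies of the model and tokenizer, and the text and embed function are just illustrative):

import torch
from datasets import Dataset
from transformers import AutoModel, AutoTokenizer

torch.set_num_threads(2)  # any value > 1 together with num_proc > 1 hangs for me

model_name = 'sentence-transformers/paraphrase-multilingual-mpnet-base-v2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

def embed(batch):
    # batched map: batch['text'] is a list of strings
    encoded = tokenizer(batch['text'], padding=True, truncation=True, return_tensors='pt')
    with torch.no_grad():
        output = model(**encoded)
    return {'embeddings': output.pooler_output.tolist()}

dataset = Dataset.from_list([{'text': 'some paragraph', 'request_idx': 0}] * 100)
dataset = dataset.map(embed, batched=True, batch_size=64, num_proc=2)  # never finishes; fine with num_proc=1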
The full code to reproduce:
from datetime import datetime
from time import time_ns

import pandas as pd
import torch
from tqdm import tqdm

from transformers import AutoModel, AutoTokenizer
from datasets import Dataset
from datasets.utils.logging import disable_progress_bar
def get_metrics_num_proc_num_threads_row(paragraph_count, batch_size, num_proc, num_threads):
    def _get_embeddings(texts):
        encoded_input = tokenizer(
            texts, padding=True, truncation=True, return_tensors='pt'
        )
        with torch.no_grad():
            encoded_input = {
                k: v.to(device) for k, v in encoded_input.items()
            }
            model_output = model(**encoded_input)
        return model_output.pooler_output.tolist()

    torch.set_num_threads(num_threads)

    # model and tokenizer could be replaced with
    # sentence-transformers/paraphrase-multilingual-mpnet-base-v2
    model_dir = '../model'
    tokenizer_dir = '../tokenizer'
    model = AutoModel.from_pretrained(model_dir)
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)
    device = 'cpu'  # 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)

    one_hundred_token_text = 'The absence of or inadequate structural support within the eye (posterior capsular deformities, zonular laxity or dehiscence) would hinder the success of IOL implantation, as this may potentially lead to lens instability or dislocation/decentration (Source: Cataract Surgery Guidelines. The Royal College of Ophthalmologists, Regent Park, London. Sept. 2010).'
    dataset = Dataset.from_list([])
    for input in [one_hundred_token_text] * paragraph_count:
        dataset = dataset.add_item(
            {'text': input, 'request_idx': 0}
        )

    start = time_ns()
    predictions = dataset.map(
        lambda x: {
            'embeddings': _get_embeddings(x['text']),
            'request_idx': x['request_idx'],
        },
        batched=True,
        batch_size=batch_size,
        num_proc=num_proc
    )
    end = time_ns()

    row = {
        'batch_size': batch_size,
        'num_proc': num_proc,
        'elapsed': end - start,
        'paragraph_count': paragraph_count
    }
    return row
def get_metrics_num_proc_num_threads(num_procs: int, num_threads: int, paragraph_counts: list[int]) -> pd.DataFrame:
    disable_progress_bar()
    rows = []
    date = datetime.now().strftime('%Y%m%d')
    batch_size = 64

    for paragraph_count in tqdm(paragraph_counts, desc='paragraph_counts'):
        row = get_metrics_num_proc_num_threads_row(
            paragraph_count=paragraph_count,
            batch_size=batch_size,
            num_proc=num_procs,  # use the function argument, not the global
            num_threads=num_threads)
        rows.append(row)

    report_procs = num_procs
    report_batches = batch_size
    report_paragraphs = '_'.join(map(str, paragraph_counts))

    df = pd.DataFrame(rows)
    df['elapsed'] = df['elapsed'].apply(lambda x: x / (10**9))  # ns -> s
    return df
metrics_hf = {}
for num_proc, num_thread in [
    (1, 1), (2, 1), (3, 1), (4, 1), (5, 1), (6, 1), (7, 1),  # work fine
    (1, 2), (1, 3), (1, 4), (1, 5), (1, 6), (1, 7),          # work fine
    (2, 2), (3, 2), (4, 2), (5, 2), (6, 2), (7, 2),          # stuck
]:
    metrics_hf[(num_proc, num_thread)] = get_metrics_num_proc_num_threads(
        num_procs=num_proc,
        num_threads=num_thread,
        paragraph_counts=[10, 25, 50, 100, 200, 400])
And the stack trace below (after terminating the hung call with a keyboard interrupt):
KeyboardInterrupt Traceback (most recent call last)
Cell In[83], line 2
1 num_proc, num_thread = 2, 2
----> 2 metrics_hf[(num_proc, num_thread)] = get_metrics_num_proc_num_threads(
3 num_procs=num_proc,
4 num_threads=num_thread,
5 paragraph_counts=[10, 25, 50, 100, 200, 400])
Cell In[62], line 9, in get_metrics_num_proc_num_threads(num_procs, num_threads, paragraph_counts)
6 batch_size=64
8 for paragraph_count in tqdm(paragraph_counts, desc='paragraph_counts'):
----> 9 row = get_metrics_num_proc_num_threads_row(
10 paragraph_count=paragraph_count,
11 batch_size=batch_size,
12 num_proc=num_proc,
13 num_threads=num_threads)
14 rows.append(row)
16 report_procs = num_proc
File ~/git/ml/document-model-pytorch/envs/lib/python3.11/site-packages/joblib/memory.py:594, in MemorizedFunc.__call__(self, *args, **kwargs)
593 def __call__(self, *args, **kwargs):
--> 594 return self._cached_call(args, kwargs)[0]
File ~/git/ml/document-model-pytorch/envs/lib/python3.11/site-packages/joblib/memory.py:537, in MemorizedFunc._cached_call(self, args, kwargs, shelving)
534 must_call = True
536 if must_call:
--> 537 out, metadata = self.call(*args, **kwargs)
538 if self.mmap_mode is not None:
539 # Memmap the output at the first call to be consistent with
540 # later calls
541 if self._verbose:
File ~/git/ml/document-model-pytorch/envs/lib/python3.11/site-packages/joblib/memory.py:779, in MemorizedFunc.call(self, *args, **kwargs)
777 if self._verbose > 0:
778 print(format_call(self.func, args, kwargs))
--> 779 output = self.func(*args, **kwargs)
780 self.store_backend.dump_item(
781 [func_id, args_id], output, verbose=self._verbose)
783 duration = time.time() - start_time
Cell In[61], line 39, in get_metrics_num_proc_num_threads_row(paragraph_count, batch_size, num_proc, num_threads)
34 dataset = dataset.add_item(
35 {'text': input, 'request_idx': 0}
36 )
38 start = time_ns()
---> 39 predictions = dataset.map(
40 lambda x: {
41 'embeddings': _get_embeddings(x['text']),
42 'request_idx': x['request_idx'],
43 },
44 batched=True,
45 batch_size=batch_size,
46 num_proc=num_proc
47 )
48 end = time_ns()
49 # result = handler.postprocess(inferenced)
File ~/git/ml/document-model-pytorch/envs/lib/python3.11/site-packages/datasets/arrow_dataset.py:563, in transmit_tasks.<locals>.wrapper(*args, **kwargs)
561 self: "Dataset" = kwargs.pop("self")
562 # apply actual function
--> 563 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
564 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
565 for dataset in datasets:
566 # Remove task templates if a column mapping of the template is no longer valid
File ~/git/ml/document-model-pytorch/envs/lib/python3.11/site-packages/datasets/arrow_dataset.py:528, in transmit_format.<locals>.wrapper(*args, **kwargs)
521 self_format = {
522 "type": self._format_type,
523 "format_kwargs": self._format_kwargs,
524 "columns": self._format_columns,
525 "output_all_columns": self._output_all_columns,
526 }
527 # apply actual function
--> 528 out: Union["Dataset", "DatasetDict"] = func(self, *args, **kwargs)
529 datasets: List["Dataset"] = list(out.values()) if isinstance(out, dict) else [out]
530 # re-apply format to the output
File ~/git/ml/document-model-pytorch/envs/lib/python3.11/site-packages/datasets/arrow_dataset.py:3046, in Dataset.map(self, function, with_indices, with_rank, input_columns, batched, batch_size, drop_last_batch, remove_columns, keep_in_memory, load_from_cache_file, cache_file_name, writer_batch_size, features, disable_nullable, fn_kwargs, num_proc, suffix_template, new_fingerprint, desc)
3038 logger.info(f"Spawning {num_proc} processes")
3039 with logging.tqdm(
3040 disable=not logging.is_progress_bar_enabled(),
3041 unit=" examples",
(...)
3044 desc=(desc or "Map") + f" (num_proc={num_proc})",
3045 ) as pbar:
-> 3046 for rank, done, content in iflatmap_unordered(
3047 pool, Dataset._map_single, kwargs_iterable=kwargs_per_job
3048 ):
3049 if done:
3050 shards_done += 1
File ~/git/ml/document-model-pytorch/envs/lib/python3.11/site-packages/datasets/utils/py_utils.py:1368, in iflatmap_unordered(pool, func, kwargs_iterable)
1366 while True:
1367 try:
-> 1368 yield queue.get(timeout=0.05)
1369 except Empty:
1370 if all(async_result.ready() for async_result in async_results) and queue.empty():
File <string>:2, in get(self, *args, **kwds)
File ~/git/ml/document-model-pytorch/envs/lib/python3.11/site-packages/multiprocess/managers.py:818, in BaseProxy._callmethod(self, methodname, args, kwds)
815 conn = self._tls.connection
817 conn.send((self._id, methodname, args, kwds))
--> 818 kind, result = conn.recv()
820 if kind == '#RETURN':
821 return result
File ~/git/ml/document-model-pytorch/envs/lib/python3.11/site-packages/multiprocess/connection.py:258, in _ConnectionBase.recv(self)
256 self._check_closed()
257 self._check_readable()
--> 258 buf = self._recv_bytes()
259 return _ForkingPickler.loads(buf.getbuffer())
File ~/git/ml/document-model-pytorch/envs/lib/python3.11/site-packages/multiprocess/connection.py:422, in Connection._recv_bytes(self, maxsize)
421 def _recv_bytes(self, maxsize=None):
--> 422 buf = self._recv(4)
423 size, = struct.unpack("!i", buf.getvalue())
424 if size == -1:
File ~/git/ml/document-model-pytorch/envs/lib/python3.11/site-packages/multiprocess/connection.py:387, in Connection._recv(self, size, read)
385 remaining = size
386 while remaining > 0:
--> 387 chunk = read(handle, remaining)
388 n = len(chunk)
389 if n == 0:
Could you please explain what I'm doing wrong? Is it not allowed to change the thread count?
PS: I've experimented on a t3a.2xlarge (8 vCPU, 32 GB RAM) EC2 instance.