Reading some tutorials, I learned that I could convert the loaded dataset into PyTorch tensors directly using load_from_disk(...).with_format("torch").
import cProfile
import pstats

from datasets import load_from_disk

with cProfile.Profile() as profiler:
    ds = load_from_disk("./arrow").with_format("torch")
    for i in range(1000):
        sample = ds[i]

stats = pstats.Stats(profiler).sort_stats("cumtime")
stats.print_stats()
This simple change improved the performance a lot:
6612931 function calls (269871 primitive calls) in 3.085 seconds
Ordered by: cumulative time
ncalls tottime percall cumtime percall filename:lineno(function)
1000 0.004 0.000 2.920 0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/arrow_dataset.py:2776(__getitem__)
1000 0.004 0.000 2.916 0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/arrow_dataset.py:2750(_getitem)
1000 0.002 0.000 2.881 0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/formatting.py:587(format_table)
1000 0.007 0.000 2.876 0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/formatting.py:394(__call__)
1000 0.003 0.000 2.869 0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/torch_formatter.py:80(format_row)
203000/1000 0.075 0.000 2.697 0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/torch_formatter.py:77(recursive_tensorize)
203000/1000 0.157 0.000 2.696 0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/utils/py_utils.py:372(map_nested)
1000 0.005 0.000 2.644 0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/utils/py_utils.py:443(<listcomp>)
7000 0.004 0.000 2.637 0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/utils/py_utils.py:340(_single_map_nested)
209000/7000 0.161 0.000 2.631 0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/torch_formatter.py:70(_recursive_tensorize)
2000/1000 0.051 0.000 2.556 0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/torch_formatter.py:74(<listcomp>)
207000 0.391 0.000 1.922 0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/torch_formatter.py:49(_tensorize)
204000 0.830 0.000 0.830 0.000 {built-in method torch.tensor}
607000 0.269 0.000 0.642 0.000 /home/mehran/.local/lib/python3.10/site-packages/numpy/core/numerictypes.py:356(issubdtype)
1214000 0.232 0.000 0.339 0.000 /home/mehran/.local/lib/python3.10/site-packages/numpy/core/numerictypes.py:282(issubclass_)
I’m not sure if there’s any more room left to improve the performance, but this alone solves my problem.
Thanks, @mariosasko, I could not have done it without you.