Local dataset loading performance: HF's arrow vs torch.load

Reading some tutorials, I learned that I can convert the loaded dataset into PyTorch tensors directly with load_from_disk(...).with_format("torch").

import cProfile
import pstats

from datasets import load_from_disk

with cProfile.Profile() as profiler:
    ds = load_from_disk("./arrow").with_format("torch")
    for i in range(1000):
        sample = ds[i]

stats = pstats.Stats(profiler).sort_stats("cumtime")
stats.print_stats()

This simple change improved performance considerably:

269871 function calls (6612931 primitive calls) in 3.085 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     1000    0.004    0.000    2.920    0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/arrow_dataset.py:2776(__getitem__)
     1000    0.004    0.000    2.916    0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/arrow_dataset.py:2750(_getitem)
     1000    0.002    0.000    2.881    0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/formatting.py:587(format_table)
     1000    0.007    0.000    2.876    0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/formatting.py:394(__call__)
     1000    0.003    0.000    2.869    0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/torch_formatter.py:80(format_row)
203000/1000    0.075    0.000    2.697    0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/torch_formatter.py:77(recursive_tensorize)
203000/1000    0.157    0.000    2.696    0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/utils/py_utils.py:372(map_nested)
     1000    0.005    0.000    2.644    0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/utils/py_utils.py:443(<listcomp>)
     7000    0.004    0.000    2.637    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/utils/py_utils.py:340(_single_map_nested)
209000/7000    0.161    0.000    2.631    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/torch_formatter.py:70(_recursive_tensorize)
2000/1000    0.051    0.000    2.556    0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/torch_formatter.py:74(<listcomp>)
   207000    0.391    0.000    1.922    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/torch_formatter.py:49(_tensorize)
   204000    0.830    0.000    0.830    0.000 {built-in method torch.tensor}
   607000    0.269    0.000    0.642    0.000 /home/mehran/.local/lib/python3.10/site-packages/numpy/core/numerictypes.py:356(issubdtype)
  1214000    0.232    0.000    0.339    0.000 /home/mehran/.local/lib/python3.10/site-packages/numpy/core/numerictypes.py:282(issubclass_)

I’m not sure whether there’s any more room left to improve the performance, but this alone solves my problem.

Thanks, @mariosasko, I couldn’t have done it without you.