Local dataset loading performance: HF's arrow vs torch.load

@mariosasko Thanks for your help. Before I report the stats, here’s my code:

```python
import cProfile
import glob
import json
import os
import pstats

import torch
from datasets import load_from_disk
from torch.utils.data import Dataset


class MyDataset(Dataset):
    """Each sample is a .json metadata file plus a torch-saved .spec tensor next to it."""

    def __init__(self, parent_folder):
        self.parent_folder = parent_folder
        json_files = glob.glob(os.path.join(parent_folder, '**/*.json'),
                               recursive=True)
        self.files = []
        for json_file in json_files:
            # Swap only the extension (str.replace would also hit '.json' inside directory names).
            spec_file = os.path.splitext(json_file)[0] + '.spec'
            self.files.append((json_file, spec_file))

    def __getitem__(self, index):
        json_file, spec_file = self.files[index]
        with open(json_file, 'r') as fp:
            _json = json.load(fp)
        # The sub-folder relative to parent_folder is the locale.
        locale = os.path.relpath(os.path.dirname(json_file),
                                 self.parent_folder)
        sample_id = os.path.splitext(os.path.basename(json_file))[0]
        _json["id"] = f"{locale}_{sample_id}"
        _json["locale"] = locale
        _json["file"] = spec_file

        spec = torch.load(spec_file)
        return torch.squeeze(spec), _json

    def __len__(self):
        # Number of (json, spec) file pairs.
        return len(self.files)


# Profile loading the Arrow dataset plus 1000 sequential reads.
with cProfile.Profile() as profiler:
    ds = load_from_disk("./arrow")
    for i in range(1000):
        sample = ds[i]
pstats.Stats(profiler).sort_stats("cumtime").print_stats()

# Profile building the file index (glob) plus 1000 sequential reads of individual files.
with cProfile.Profile() as profiler:
    ds = MyDataset("./individual_files")
    for i in range(1000):
        sample = ds[i]
pstats.Stats(profiler).sort_stats("cumtime").print_stats()
```

And here are the profiler reports:

Arrow

226695 function calls (201237 primitive calls) in 14.371 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     1000    0.004    0.000   14.314    0.014 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/arrow_dataset.py:2776(__getitem__)
     1000    0.004    0.000   14.309    0.014 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/arrow_dataset.py:2750(_getitem)
     1000    0.002    0.000   14.269    0.014 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/formatting.py:587(format_table)
     1000    0.001    0.000   14.265    0.014 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/formatting.py:394(__call__)
     1000    0.002    0.000   14.265    0.014 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/formatting.py:428(format_row)
     1000    0.003    0.000   14.244    0.014 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/formatting.py:143(extract_row)
     1000   14.238    0.014   14.238    0.014 {method 'to_pydict' of 'pyarrow.lib.Table' objects}
        1    0.000    0.000    0.057    0.057 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/load.py:1820(load_from_disk)
        1    0.000    0.000    0.057    0.057 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/arrow_dataset.py:1572(load_from_disk)
11319/473    0.012    0.000    0.035    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/copy.py:128(deepcopy)
  437/272    0.000    0.000    0.033    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/copy.py:227(_deepcopy_dict)
        1    0.000    0.000    0.033    0.033 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/arrow_dataset.py:2601(with_format)
   3121/4    0.004    0.000    0.033    0.008 /home/mehran/.conda/envs/whisper/lib/python3.10/copy.py:259(_reconstruct)
  536/137    0.001    0.000    0.032    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/copy.py:201(_deepcopy_list)
    134/1    0.000    0.000    0.032    0.032 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/table.py:167(__deepcopy__)
    134/1    0.001    0.000    0.032    0.032 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/table.py:66(_deepcopy)
     1000    0.002    0.000    0.030    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/formatting.py:547(query_table)
10798/424    0.003    0.000    0.027    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/copy.py:264(<genexpr>)

Torch

4183801 function calls (4183741 primitive calls) in 2.283 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     1000    0.009    0.000    1.415    0.001 /tmp/ipykernel_56163/3477044752.py:14(__getitem__)
     1000    0.008    0.000    1.252    0.001 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/torch/serialization.py:671(load)
     1000    0.009    0.000    0.886    0.001 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/torch/serialization.py:1104(_load)
     1000    0.007    0.000    0.868    0.001 {method 'load' of '_pickle.Unpickler' objects}
        1    0.044    0.044    0.860    0.860 /tmp/ipykernel_56163/3477044752.py:6(__init__)
     1000    0.002    0.000    0.826    0.001 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/torch/serialization.py:1125(persistent_load)
     1000    0.810    0.001    0.821    0.001 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/torch/serialization.py:1109(load_tensor)
        1    0.022    0.022    0.793    0.793 /home/mehran/.conda/envs/whisper/lib/python3.10/glob.py:13(glob)
199849/199818    0.041    0.000    0.770    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/glob.py:53(_iglob)
       30    0.000    0.000    0.376    0.013 /home/mehran/.conda/envs/whisper/lib/python3.10/glob.py:93(_glob1)
       60    0.031    0.001    0.305    0.005 /home/mehran/.conda/envs/whisper/lib/python3.10/glob.py:162(_listdir)
   399752    0.259    0.000    0.274    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/glob.py:128(_iterdir)
       30    0.075    0.002    0.218    0.007 /home/mehran/.conda/envs/whisper/lib/python3.10/fnmatch.py:54(filter)
   200877    0.135    0.000    0.207    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/posixpath.py:71(join)
     1000    0.003    0.000    0.152    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/torch/serialization.py:66(_is_zipfile)
     4000    0.148    0.000    0.148    0.000 {method 'read' of '_io.BufferedReader' objects}
       31    0.000    0.000    0.147    0.005 /home/mehran/.conda/envs/whisper/lib/python3.10/glob.py:121(_glob2)
    59/30    0.000    0.000    0.147    0.005 /home/mehran/.conda/envs/whisper/lib/python3.10/glob.py:167(_rlistdir)

I had to cut off the lower lines of both reports because the platform would not let me post a reply that long.
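To put the two reports side by side, here is the plain arithmetic on the cumulative times copied from them (the variable names are just my own labels). It works out to roughly 14.3 ms vs. 1.4 ms per item, so `__getitem__` is about 10x slower with Arrow, and `Table.to_pydict` accounts for about 99.5% of the Arrow `__getitem__` time. The whole-run ratio is only about 6.3x because the Torch run also includes the 0.86 s spent in `glob` inside `MyDataset.__init__`.

```python
# Cumulative times (seconds) copied from the two cProfile reports; both loops read 1000 items.
N_ITEMS = 1000

# Arrow run
arrow_total = 14.371       # whole profiled block
arrow_getitem = 14.314     # arrow_dataset.py:2776(__getitem__)
arrow_to_pydict = 14.238   # pyarrow.lib.Table.to_pydict
arrow_load = 0.057         # load_from_disk

# Torch run
torch_total = 2.283        # whole profiled block
torch_getitem = 1.415      # MyDataset.__getitem__
torch_init = 0.860         # MyDataset.__init__ (mostly glob)

# Per-item cost in milliseconds.
arrow_ms_per_item = 1000 * arrow_getitem / N_ITEMS
torch_ms_per_item = 1000 * torch_getitem / N_ITEMS

getitem_ratio = arrow_getitem / torch_getitem      # per-item slowdown of Arrow
to_pydict_share = arrow_to_pydict / arrow_getitem  # fraction of Arrow __getitem__ spent in to_pydict
total_ratio = arrow_total / torch_total            # also includes setup (load_from_disk / glob)

print(f"Arrow: {arrow_ms_per_item:.1f} ms/item, Torch: {torch_ms_per_item:.1f} ms/item")
print(f"__getitem__ ratio: {getitem_ratio:.1f}x, whole-run ratio: {total_ratio:.1f}x")
print(f"to_pydict share of Arrow __getitem__: {to_pydict_share:.1%}")
```

The Arrow total is exactly `__getitem__` plus `load_from_disk` (14.314 + 0.057), so nothing else of significance is hiding in the truncated lines of that report.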

This does not show the 16x difference I mentioned before, but I suspect that was a slightly different scenario: back then I was accessing the records in random order, which might have affected performance, or I was reading more records.
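To check the random-access guess, a helper like the one below could time sequential vs. random reads while leaving dataset construction (`load_from_disk`, the `glob` in `__init__`) out of the measurement. It's a stdlib-only sketch: `time_access` and `Indexable` are hypothetical names, and the only assumption about the dataset is that it supports `len()` and integer indexing, which a Hugging Face `Dataset` does (and `MyDataset` does too, as long as `__len__` returns the number of file pairs).

```python
import random
import time
from typing import Any, Protocol


class Indexable(Protocol):
    """Anything with len() and integer indexing (HF Dataset, torch Dataset, list, ...)."""

    def __len__(self) -> int: ...
    def __getitem__(self, index: int) -> Any: ...


def time_access(dataset: Indexable, n: int, shuffle: bool, seed: int = 0) -> float:
    """Average seconds per item for reading n items from an already-built dataset."""
    n = min(n, len(dataset))
    if n <= 0:
        raise ValueError("need at least one item to time")
    if shuffle:
        # Random access: n distinct indices drawn from the whole dataset.
        indices = random.Random(seed).sample(range(len(dataset)), n)
    else:
        # Sequential access: the first n rows, as in the profiling loops.
        indices = list(range(n))
    start = time.perf_counter()  # indices are prepared before the clock starts
    for i in indices:
        dataset[i]  # result discarded; only the read cost is measured
    return (time.perf_counter() - start) / n


if __name__ == "__main__":
    toy = list(range(10_000))  # stand-in; swap in load_from_disk("./arrow") or MyDataset(...)
    print("sequential s/item:", time_access(toy, 1000, shuffle=False))
    print("random     s/item:", time_access(toy, 1000, shuffle=True))
```

Running each variant more than once matters: a second pass over the same files is likely served from the OS page cache, so it can look much faster than a cold read.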

In any case, there’s a huge difference between the two. It would be great if I could somehow improve the performance with Arrow, since it saves a lot of space on my drive (individual_files: 184 GB vs. arrow: 61 GB).