Local dataset loading performance: HF's arrow vs torch.load

I’ve made a local dataset by transforming Common Voice’s audio files into spectrograms. Each sample holds a tensor of shape (513, 128) along with some tabular features. I then saved this dataset in two formats.
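For scale, here is a back-of-the-envelope estimate of the raw tensor payload per sample. It assumes the spectrograms are stored as float32, because the post does not state the dtype:

```python
# Raw payload of one (513, 128) spectrogram, ASSUMING float32 (4 bytes/element);
# the dtype is not stated in the post.
FREQ_BINS, FRAMES = 513, 128
BYTES_PER_FLOAT32 = 4

elements = FREQ_BINS * FRAMES              # values per sample
raw_bytes = elements * BYTES_PER_FLOAT32   # bytes per sample before any file overhead

print(f"{elements} elements, {raw_bytes} bytes (~{raw_bytes / 1024:.1f} KiB)")
# prints: 65664 elements, 262656 bytes (~256.5 KiB)
```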

First, I used torch.save (plus a JSON file for the other features) to save each sample to the hard drive, so each sample consists of two files.
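Below is a minimal sketch of that two-file layout. It assumes the `<locale>/<id>.json` + `<locale>/<id>.spec` naming that the `MyDataset` loader later in this thread expects. `save_sample` and `load_sample` are hypothetical helpers. `pickle` stands in for `torch.save`, which pickles under the hood, so the sketch runs without PyTorch:

```python
import json
import os
import pickle
import tempfile


def save_sample(parent_folder, locale, sample_id, spectrogram, features):
    """Write one sample as <locale>/<id>.spec (tensor) + <locale>/<id>.json (features).

    pickle stands in for torch.save here; with real tensors you would call
    torch.save(spectrogram, spec_path) instead.
    """
    folder = os.path.join(parent_folder, locale)
    os.makedirs(folder, exist_ok=True)
    spec_path = os.path.join(folder, f"{sample_id}.spec")
    json_path = os.path.join(folder, f"{sample_id}.json")
    with open(spec_path, "wb") as fp:
        pickle.dump(spectrogram, fp)
    with open(json_path, "w") as fp:
        json.dump(features, fp)
    return json_path, spec_path


def load_sample(json_path):
    """Read one pair back: two file opens plus one unpickle per sample."""
    spec_path = json_path.replace(".json", ".spec")
    with open(json_path) as fp:
        features = json.load(fp)
    with open(spec_path, "rb") as fp:
        spectrogram = pickle.load(fp)
    return spectrogram, features


# Usage: write one sample into a throwaway folder and read it back
with tempfile.TemporaryDirectory() as root:
    json_path, _ = save_sample(root, "en", "clip_0", [[0.5, 1.5]], {"client_id": "abc"})
    spectrogram, features = load_sample(json_path)
    print(spectrogram, features)  # [[0.5, 1.5]] {'client_id': 'abc'}
```

Every read on this path pays for two file opens and a deserialization. That is the per-sample cost `torch.load` is measured against below.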

Second, I used HF’s datasets package to save the same dataset in Arrow format:

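from datasets import Dataset


def arrow_generator(data_samples):
    # Return a zero-argument generator function, which is what Dataset.from_generator expects
    def _gen():
        for rec in data_samples:
            yield {
                "id": rec["id"],
                "client_id": rec["client_id"],
                "locale": rec["locale"],
                "spectrogram": rec["spectrogram"],  # PyTorch tensor of shape (513, 128)
            }

    return _gen


# data_samples: the iterable of sample dicts built earlier (not shown here)
ds = Dataset.from_generator(arrow_generator(data_samples))
ds.save_to_disk("./arrow_dataset")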

Comparing the two, the Arrow format takes half as much disk space, but reading a batch takes 16x longer (reading through load_from_disk is 16 times slower than torch.load).

Is this expected, or am I doing something wrong?

Hi! torch.save uses pickle under the hood to serialize objects, and according to this blog post, pickle is slower than Feather, a format very similar to ours (we plan to switch to Feather eventually).

Do you mind profiling the load_from_disk call using the code below?

import cProfile, pstats
from datasets import load_from_disk

with cProfile.Profile() as profiler:
    ds = load_from_disk(...)

stats = pstats.Stats(profiler).sort_stats("cumtime")
stats.print_stats()

That way, we can confirm whether this is an Arrow issue.


@mariosasko Thanks for your help. Before I report the stats, here’s my code:

import cProfile
import glob
import json
import os
import pstats

import torch
from datasets import load_from_disk
from torch.utils.data import Dataset


class MyDataset(Dataset):
    """One sample = <locale>/<id>.json (features) + <locale>/<id>.spec (tensor)."""

    def __init__(self, parent_folder):
        self.parent_folder = parent_folder
        json_files = glob.glob(os.path.join(parent_folder, "**/*.json"), recursive=True)
        # Pair each JSON file with the spectrogram file saved next to it
        self.files = [(json_file, json_file.replace(".json", ".spec")) for json_file in json_files]

    def __getitem__(self, index):
        json_file, spec_file = self.files[index]
        with open(json_file, "r") as fp:
            _json = json.load(fp)
        # The parent directory name is the locale; the file stem is the sample id
        locale = os.path.relpath(os.path.dirname(json_file), self.parent_folder)
        sample_id = os.path.splitext(os.path.basename(json_file))[0]
        _json["id"] = f"{locale}_{sample_id}"
        _json["locale"] = locale
        _json["file"] = spec_file

        spec = torch.load(spec_file)
        return torch.squeeze(spec), _json

    def __len__(self):
        return len(self.files)  # number of (json, spec) pairs


with cProfile.Profile() as profiler:
    ds = load_from_disk("./arrow")
    for i in range(1000):
        sample = ds[i]
pstats.Stats(profiler).sort_stats("cumtime").print_stats()

with cProfile.Profile() as profiler:
    ds = MyDataset("./individual_files")
    for i in range(1000):
        sample = ds[i]
pstats.Stats(profiler).sort_stats("cumtime").print_stats()

And here’s the profiling report:

Arrow

226695 function calls (201237 primitive calls) in 14.371 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     1000    0.004    0.000   14.314    0.014 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/arrow_dataset.py:2776(__getitem__)
     1000    0.004    0.000   14.309    0.014 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/arrow_dataset.py:2750(_getitem)
     1000    0.002    0.000   14.269    0.014 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/formatting.py:587(format_table)
     1000    0.001    0.000   14.265    0.014 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/formatting.py:394(__call__)
     1000    0.002    0.000   14.265    0.014 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/formatting.py:428(format_row)
     1000    0.003    0.000   14.244    0.014 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/formatting.py:143(extract_row)
     1000   14.238    0.014   14.238    0.014 {method 'to_pydict' of 'pyarrow.lib.Table' objects}
        1    0.000    0.000    0.057    0.057 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/load.py:1820(load_from_disk)
        1    0.000    0.000    0.057    0.057 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/arrow_dataset.py:1572(load_from_disk)
11319/473    0.012    0.000    0.035    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/copy.py:128(deepcopy)
  437/272    0.000    0.000    0.033    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/copy.py:227(_deepcopy_dict)
        1    0.000    0.000    0.033    0.033 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/arrow_dataset.py:2601(with_format)
   3121/4    0.004    0.000    0.033    0.008 /home/mehran/.conda/envs/whisper/lib/python3.10/copy.py:259(_reconstruct)
  536/137    0.001    0.000    0.032    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/copy.py:201(_deepcopy_list)
    134/1    0.000    0.000    0.032    0.032 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/table.py:167(__deepcopy__)
    134/1    0.001    0.000    0.032    0.032 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/table.py:66(_deepcopy)
     1000    0.002    0.000    0.030    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/formatting.py:547(query_table)
10798/424    0.003    0.000    0.027    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/copy.py:264(<genexpr>)

Torch

4183801 function calls (4183741 primitive calls) in 2.283 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     1000    0.009    0.000    1.415    0.001 /tmp/ipykernel_56163/3477044752.py:14(__getitem__)
     1000    0.008    0.000    1.252    0.001 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/torch/serialization.py:671(load)
     1000    0.009    0.000    0.886    0.001 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/torch/serialization.py:1104(_load)
     1000    0.007    0.000    0.868    0.001 {method 'load' of '_pickle.Unpickler' objects}
        1    0.044    0.044    0.860    0.860 /tmp/ipykernel_56163/3477044752.py:6(__init__)
     1000    0.002    0.000    0.826    0.001 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/torch/serialization.py:1125(persistent_load)
     1000    0.810    0.001    0.821    0.001 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/torch/serialization.py:1109(load_tensor)
        1    0.022    0.022    0.793    0.793 /home/mehran/.conda/envs/whisper/lib/python3.10/glob.py:13(glob)
199849/199818    0.041    0.000    0.770    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/glob.py:53(_iglob)
       30    0.000    0.000    0.376    0.013 /home/mehran/.conda/envs/whisper/lib/python3.10/glob.py:93(_glob1)
       60    0.031    0.001    0.305    0.005 /home/mehran/.conda/envs/whisper/lib/python3.10/glob.py:162(_listdir)
   399752    0.259    0.000    0.274    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/glob.py:128(_iterdir)
       30    0.075    0.002    0.218    0.007 /home/mehran/.conda/envs/whisper/lib/python3.10/fnmatch.py:54(filter)
   200877    0.135    0.000    0.207    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/posixpath.py:71(join)
     1000    0.003    0.000    0.152    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/torch/serialization.py:66(_is_zipfile)
     4000    0.148    0.000    0.148    0.000 {method 'read' of '_io.BufferedReader' objects}
       31    0.000    0.000    0.147    0.005 /home/mehran/.conda/envs/whisper/lib/python3.10/glob.py:121(_glob2)
    59/30    0.000    0.000    0.147    0.005 /home/mehran/.conda/envs/whisper/lib/python3.10/glob.py:167(_rlistdir)

I had to cut the lower lines of both reports because the forum would not let me post that many characters.
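One way to keep such reports short is to pass a row limit to `print_stats` and capture the text with `io.StringIO`. In this sketch, `profile_reads` is a hypothetical helper that runs the same indexed-read loop for any dataset-like object:

```python
import cProfile
import io
import pstats


def profile_reads(ds, n=1000, top=15):
    """Profile n indexed reads ds[0..n-1] and return the `top` rows as text."""
    profiler = cProfile.Profile()
    profiler.enable()
    for i in range(n):
        _ = ds[i]
    profiler.disable()
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream).sort_stats("cumtime")
    stats.print_stats(top)  # an int restriction keeps only the first `top` rows
    return stream.getvalue()
```

For example, `print(profile_reads(load_from_disk("./arrow"), top=20))` prints the header plus the 20 most expensive rows by cumulative time.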

This does not show the 16x difference I mentioned before (here it is roughly 6x: 14.4 s vs 2.3 s), but I guess that was a slightly different scenario. For instance, I was accessing the records randomly back then, which might have affected performance, or I was reading more records.
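To check whether access order explains the gap, you can time both datasets with plain wall-clock measurements under both orders. This also sidesteps cProfile's per-call overhead, which inflates code paths that make many Python calls (the Torch profile above records over 4 million). The sketch below uses a hypothetical `time_reads` helper. Its random mode draws `n` distinct indices from the whole dataset, so it assumes `len(ds)` works:

```python
import random
import time


def time_reads(ds, n=1000, shuffle=False, seed=0):
    """Return wall-clock seconds spent reading n samples from ds by index.

    shuffle=False reads indices 0..n-1 in order; shuffle=True reads n distinct
    indices drawn at random from the whole dataset (requires len(ds)).
    """
    if shuffle:
        indices = random.Random(seed).sample(range(len(ds)), n)
    else:
        indices = list(range(n))
    start = time.perf_counter()
    for i in indices:
        _ = ds[i]
    return time.perf_counter() - start
```

Comparing `time_reads(ds)` with `time_reads(ds, shuffle=True)` on each dataset would show how much of the difference comes from random access alone.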

In any case, there’s a huge difference between the two. It would be great if I could somehow improve the performance with Arrow, since it saves a lot of disk space (individual_files: 184 GB vs. arrow: 61 GB).

Reading some tutorials, I learned that I can have the loaded dataset return PyTorch tensors directly by calling load_from_disk(...).with_format("torch").

with cProfile.Profile() as profiler:
    ds = load_from_disk("./arrow").with_format("torch")
    for i in range(1000):
        sample = ds[i]
    stats = pstats.Stats(profiler).sort_stats("cumtime")
    stats.print_stats()

This simple change improved the performance a lot (14.371 s down to 3.085 s for the same 1000 reads):

269871 function calls (6612931 primitive calls) in 3.085 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     1000    0.004    0.000    2.920    0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/arrow_dataset.py:2776(__getitem__)
     1000    0.004    0.000    2.916    0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/arrow_dataset.py:2750(_getitem)
     1000    0.002    0.000    2.881    0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/formatting.py:587(format_table)
     1000    0.007    0.000    2.876    0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/formatting.py:394(__call__)
     1000    0.003    0.000    2.869    0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/torch_formatter.py:80(format_row)
203000/1000    0.075    0.000    2.697    0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/torch_formatter.py:77(recursive_tensorize)
203000/1000    0.157    0.000    2.696    0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/utils/py_utils.py:372(map_nested)
     1000    0.005    0.000    2.644    0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/utils/py_utils.py:443(<listcomp>)
     7000    0.004    0.000    2.637    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/utils/py_utils.py:340(_single_map_nested)
209000/7000    0.161    0.000    2.631    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/torch_formatter.py:70(_recursive_tensorize)
2000/1000    0.051    0.000    2.556    0.003 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/torch_formatter.py:74(<listcomp>)
   207000    0.391    0.000    1.922    0.000 /home/mehran/.conda/envs/whisper/lib/python3.10/site-packages/datasets/formatting/torch_formatter.py:49(_tensorize)
   204000    0.830    0.000    0.830    0.000 {built-in method torch.tensor}
   607000    0.269    0.000    0.642    0.000 /home/mehran/.local/lib/python3.10/site-packages/numpy/core/numerictypes.py:356(issubdtype)
  1214000    0.232    0.000    0.339    0.000 /home/mehran/.local/lib/python3.10/site-packages/numpy/core/numerictypes.py:282(issubclass_)

I’m not sure if there’s any more room left to improve the performance, but this alone solves my problem.

Thanks, @mariosasko, I couldn’t have done it without you.

You should get better performance by using

for sample in iter(ds):
    ...

instead of

for i in range(1000):
    sample = ds[i]
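Along the same lines, the profiles show `query_table` and `format_table` running once per `ds[i]` call. Reading contiguous slices may amortize that fixed cost. This sketch assumes slice access `ds[start:stop]`: a `datasets.Dataset` accepts slices and returns a dict of columns. `batch_bounds` is a hypothetical helper, and 64 is an arbitrary batch size:

```python
def batch_bounds(num_rows, batch_size):
    """Yield (start, stop) pairs that cover range(num_rows) in batch_size chunks."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, num_rows, batch_size):
        yield start, min(start + batch_size, num_rows)


# Hypothetical usage with the Arrow-backed `ds` from this thread:
# one slice per batch instead of one ds[i] call per row.
# for start, stop in batch_bounds(len(ds), 64):
#     batch = ds[start:stop]  # dict mapping column name -> batch of values
```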

Also, in your first profile almost all of the time (14.2 of 14.4 seconds) is spent in PyArrow’s to_pydict. Arrow objects need to be converted to their Python representation, and that conversion is not optimized for most Arrow types, so there isn’t much we can do about it.
