import time
import torch
import argparse
from datasets import load_dataset
from torch.utils.data import DataLoader


def test_dataset_load_speed():
    """Test the speed of loading a dataset using the datasets library"""
    print("Testing datasets library loading speed...")
    start_time = time.time()
    dataset = load_dataset("./data/parquet", keep_in_memory=True)
    load_time = time.time() - start_time
    print(f"Time taken to load dataset: {load_time:.4f} seconds")
    print(
        f"Dataset size: train set {len(dataset['train'])} samples, validation set {len(dataset['validation'])} samples"
    )
    print(f"Feature columns: {dataset['train'].column_names}")
    return dataset


def test_random_access_speed(dataset, num_samples=100):
    """Test the speed of random access to dataset items"""
    print("\nTesting random access speed...")
    train_dataset = dataset["train"]
    total_items = len(train_dataset)
    if total_items == 0:
        print("Training dataset is empty, cannot test access speed")
        return
    # Adjust sample count to avoid exceeding dataset size
    # num_samples = min(num_samples, total_items)
    start_time = time.time()
    for _ in range(num_samples):
        idx = torch.randint(0, total_items, (1,)).item()
        _ = train_dataset[idx]
    access_time = time.time() - start_time
    print(
        f"Time taken to randomly access {num_samples} samples: {access_time:.4f} seconds"
    )
    print(f"Average access time per sample: {access_time / num_samples * 1000:.2f} ms")


def test_by_dataloader(dataset, batch_size, num_workers):
    """Test the speed of loading the dataset using a DataLoader"""
    print(f"\nTesting data loading speed with batch size {batch_size}...")
    dataloader = DataLoader(dataset["train"], batch_size=batch_size, num_workers=num_workers)
    start_time = 0
    for batch in dataloader:
        # Start timing after the first batch so worker startup is excluded
        if start_time == 0:
            start_time = time.time()
        # Simulate data processing
        _ = batch
    load_time = time.time() - start_time
    print(f"Time taken to load batches of size {batch_size}: {load_time:.4f} seconds")


def main():
    parser = argparse.ArgumentParser(
        description="Test the reading speed of the datasets library and data loader"
    )
    parser.add_argument(
        "--random-samples",
        type=int,
        default=100,
        help="Number of samples to test random access",
    )
    args = parser.parse_args()
    print("Starting datasets library read speed test...\n")
    # Test dataset loading speed
    dataset = test_dataset_load_speed()
    dataset.with_format("torch")
    # Test random access speed
    test_random_access_speed(dataset, args.random_samples)
    # Test DataLoader loading speed
    # test_by_dataloader(dataset, 1, 2)
    print("\nTest completed!")


if __name__ == "__main__":
    main()
And I got this result:
➜ uv run .\test.py
Starting datasets library read speed test...
Testing datasets library loading speed...
Time taken to load dataset: 0.2677 seconds
Dataset size: train set 44 samples, validation set 2 samples
Feature columns: ['f0', 'volume', 'aug_vol', 'spk_id', 'frame_len', 'pitch_aug', 'mel', 'aug_mel', 'units']
Testing random access speed...
Time taken to randomly access 100 samples: 24.9198 seconds
Average access time per sample: 249.20 ms
Test completed!
The training .parquet file is only 175 MB, yet loading the same data from fragmented .npy files is significantly faster. I'm puzzled by this discrepancy and would appreciate any insights!
Certain data types, such as lists, are slower to load in pure Python than others. If your dataset contains arrays or long lists, it's faster to load them as NumPy arrays, e.g.
ds = ds.with_format("numpy")
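Applied to your script, that might look something like this (a minimal sketch; note that with_format returns a new dataset rather than modifying it in place, so the result has to be reassigned):

dataset = load_dataset("./data/parquet", keep_in_memory=True)
# with_format returns a new DatasetDict, so reassign it instead of calling it in place
dataset = dataset.with_format("numpy")
# Array columns now come back as numpy.ndarray instead of nested Python lists
sample = dataset["train"][0]
print(type(sample["mel"]))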
Btw, you can also access multiple examples faster by passing a list of indices to ds[...], which decodes them in a single batched call instead of one sample at a time.
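For example (a rough sketch continuing the snippet above; the index values are arbitrary):

indices = [0, 7, 21, 42]
batch = dataset["train"][indices]  # one call returns a dict of batched columns
print(list(batch.keys()))          # same column names as before
print(len(batch["f0"]))            # 4 entries, one per requested index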