Querying column is slow for datasets with indices mapping
I have a large dataset and want to query the values of a specific column for batch packing. It turns out that querying is very slow if I use a dataset with an indices mapping (e.g., after train_test_split or select operations).

The code below reproduces the issue:

import time
import numpy as np
from datasets import Dataset

input_ids = np.random.randint(0, 60_000, (500_000, 128)).tolist()
length = np.random.randint(3, 128, (500_000,)).tolist()

dataset = Dataset.from_dict({"input_ids": input_ids, "length": length})
dataset_dict = dataset.train_test_split(test_size=0.1)

# ---------------------------------------------------------------
# Original dataset
start = time.time()
_ = dataset["length"]
print(f"Operation took {time.time() - start:.2f} seconds")
# Operation took 0.15 seconds
# ---------------------------------------------------------------

# ---------------------------------------------------------------
# Dataset with indices mapping
start = time.time()
_ = dataset_dict["train"]["length"]
print(f"Operation took {time.time() - start:.2f} seconds")
# Operation took 5.74 seconds
# ---------------------------------------------------------------

It takes forever to load the column values for my 500+ GB dataset, so I have to call flatten_indices on each split to get deep copies that query quickly. But flatten_indices itself is slow and also takes too much time, which makes preprocessing very painful.

Is there a way to achieve acceptable performance without flattening indices?

Which version are you using? The latest versions give dramatic improvements on such querying operations.

I’m using the latest published version (1.6.2).

Hi! Since you used train_test_split, the new train dataset is shuffled. Querying a shuffled dataset takes time because the elements have to be re-ordered, so the bigger the dataset, the longer it takes to query all the data of a shuffled column.

If you want more speed, either don’t shuffle (pass shuffle=False to train_test_split), or use flatten_indices once and for all (it re-orders the rows correctly).
