When I load a dataset, filter it, and save the result via `save_to_disk`, like this:
```python
from datasets import load_dataset

# Load the dataset
dataset = load_dataset(input_dir)

# Filter the dataset down to the rows the user won
user_wins_dataset = dataset["train"].filter(lambda x: x["winner"] == "user")

# Save the filtered dataset... this saves the filtered-out rows, too :(
user_wins_dataset.save_to_disk(output_dir)
```
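For scale, this is how I'm measuring what actually lands on disk (standard library only; `output_dir` is the same placeholder as above):

```python
from pathlib import Path

# Total size of the Arrow shards that save_to_disk wrote
arrow_bytes = sum(f.stat().st_size for f in Path(output_dir).glob("*.arrow"))
print(f"{arrow_bytes / 1e9:.2f} GB on disk for {len(user_wins_dataset)} rows")
```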
The Arrow files it writes contain the entire dataset, including every row I filtered out. Because the dataset is massive, this severely slows down my data pipeline. Why isn't `datasets` saving only the filtered rows, as I would expect, and how can I make it do that?
When I load the saved dataset back, its `len()` is correct (24680, much reduced), but I notice that the `dataset_info.json` file still reports the original size (28608139).
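That number comes straight out of the saved metadata; I'm reading it like this (I believe the key is `size_in_bytes`, but treat that as an assumption, since it may differ across `datasets` versions):

```python
import json
from pathlib import Path

# Metadata written by save_to_disk alongside the Arrow files
info = json.loads((Path(output_dir) / "dataset_info.json").read_text())
print(info.get("size_in_bytes"))  # 28608139 here, i.e. the pre-filter size
```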
`flatten_indices` doesn't appear to help either:

```python
d2 = user_wins_dataset.flatten_indices(keep_in_memory=True)
d2.save_to_disk("flattened")  # Same size, all rows saved :(
```
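As a possible workaround, would round-tripping through Parquet be the way to go? My understanding is that the export methods apply the indices mapping, so only the kept rows should be written, but I'd prefer to stay with `save_to_disk` (untested sketch; the path is a placeholder):

```python
from datasets import load_dataset

# Export should materialize only the rows that survived the filter
user_wins_dataset.to_parquet("user_wins.parquet")

# Reload as a plain dataset with no indices mapping attached
reloaded = load_dataset("parquet", data_files="user_wins.parquet", split="train")
print(len(reloaded))  # expecting 24680
```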