Filtering performance

Hi All,

I’m trying to filter a dataset that has about 1 million rows and about a dozen columns (some of those are Array2D, so the Arrow file is ~100 GB, but I don’t think that’s very relevant).
The column I’m filtering on is made up of strings 100-300 characters long, and I have a set of strings that I want to exclude.

Each of the following takes about 1min 30s to execute:

a = dataset.filter(lambda x: x not in exclude, input_columns=['column'])
a = dataset.filter(lambda x: x not in exclude, input_columns=['column'], batch_size=None)
a = dataset.filter(lambda xs: [x not in exclude for x in xs], input_columns=['column'], batched=True)
a = dataset.filter(lambda xs: [x not in exclude for x in xs], input_columns=['column'], batch_size=None, batched=True)

But this only takes 1.85s:

a = dataset['column']
b = [i for i, x in enumerate(a) if x not in exclude]
c = dataset.select(b)

Is this expected, or am I doing something wrong with the dataset?
(The second approach solves the problem for me; I’m more interested in why there is such a large difference.)

Hi! Which version of datasets are you using? We’ve made some improvements in the latest release (2.8.0) to optimize decoding, so use this version for the best performance.

Also, unlike select (which creates an indices mapping), filter writes a new dataset to disk/memory, which can take some time for larger datasets (with some benefits, such as faster indexing).
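
For illustration, a minimal sketch of what the indices mapping means for select (toy dataset; the column name and values are made up):

from datasets import Dataset

# Toy stand-in for the real dataset.
ds = Dataset.from_dict({'column': ['a', 'b', 'c', 'd', 'e', 'f']})

# select() does not copy any rows; it just records which indices to expose
# on top of the original Arrow data.
subset = ds.select([0, 2, 4])
print(subset['column'])  # ['a', 'c', 'e']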

I’m on 2.7.1, I’ll try upgrading. But writing out the new dataset probably explains the difference, as there is close to 100 GB to be written even after the filtering.

That’s not true - filter does add an indices mapping. We should update the docstring.

@lhoestq can I confirm that only the valid indices (according to the filter criteria) are cached by filter(), and that it does not actually create a cached copy of the dataset containing only the valid entries?

Yes correct, filter() only stores the indices to save disk space.

For people who want to rewrite the dataset completely (e.g. to end up with contiguous data and get faster reads), there is ds.flatten_indices(), which rewrites the dataset and removes the indices mapping.
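
As a minimal sketch of that workflow (toy dataset; the column name and exclude set are made up):

from datasets import Dataset

# Toy stand-in for the real dataset.
ds = Dataset.from_dict({'column': ['a', 'b', 'c', 'd']})
exclude = {'b', 'd'}

# filter() keeps the original Arrow data and only stores an indices mapping
# of the rows that passed.
filtered = ds.filter(lambda xs: [x not in exclude for x in xs], input_columns=['column'], batched=True)

# flatten_indices() rewrites the remaining rows contiguously and drops the
# indices mapping, trading a one-off write for faster subsequent reads.
contiguous = filtered.flatten_indices()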
