Remove a row/specific index from the dataset

Given the code

from datasets import load_dataset

dataset = load_dataset("glue", "mrpc", split='train')
idx = 0

How can I remove row 0 (dataset[0]) from this dataset?

The only way I can think of for now is using dataset.select(), and then selecting every index except 0, but that doesn’t seem efficient.

Hi!

You can do dataset = load_dataset("glue", "mrpc", split='train[1:]') to skip the first example while loading the dataset.

Why do you think select is not efficient? It depends on the ops you use afterward, but select alone is very efficient as it only creates an indices mapping, which is (almost) equal to list(indices), and not a new PyArrow table.
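
To see why select is cheap, here is a toy model in plain Python of a view that stores only an indices mapping over an immutable table. It is a sketch of the idea, not the datasets implementation; the class name IndexedView and the tuple standing in for a PyArrow table are assumptions for illustration.

```python
from typing import Any, Iterable, List, Sequence


class IndexedView:
    """Hypothetical read-only view over an immutable table via an indices mapping."""

    def __init__(self, table: Sequence[Any], indices: Iterable[int]):
        self._table = table                        # shared with the caller, never copied
        self._indices: List[int] = list(indices)   # the only new allocation

    def __len__(self) -> int:
        return len(self._indices)

    def __getitem__(self, i: int) -> Any:
        # Translate a position in the view into a position in the base table.
        return self._table[self._indices[i]]

    def select(self, indices: Iterable[int]) -> "IndexedView":
        # Compose mappings: `indices` are positions in this view, not in the base table.
        return IndexedView(self._table, [self._indices[i] for i in indices])


rows = ("row0", "row1", "row2", "row3", "row4")  # stands in for an immutable Arrow table
exclude = {0, 3}
view = IndexedView(rows, (i for i in range(len(rows)) if i not in exclude))
print(len(view), [view[i] for i in range(len(view))])
# 3 ['row1', 'row2', 'row4']
```

Selecting from a view composes the mappings instead of copying rows: view.select([0, 2]) reads rows 1 and 4 of the base table.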

Hi, thank you for your reply.

The issue is that I need to remove arbitrary rows from the dataset, not just idx = 0 but something like idxs = [76, 3, 384, 10]. Currently I do this by selecting every index that is not in idxs. This works, but I feel like there should be a better way to do it.

“Better way” in terms of the API design? If so, do you have an API in mind? Or better in terms of speed?
Removing rows is not easy to implement efficiently because PyArrow tables, which datasets uses to store data, are immutable. If your dataset is small enough to fit in memory, you could use pandas instead (ds.to_pandas()).

In summary, the current solution seems to be to select all of the indices except the ones you don’t want.

So in this example, something like:

from datasets import load_dataset

# load dataset
dataset = load_dataset("glue", "mrpc", split="train")

# indices we don't want (a set gives O(1) membership checks)
exclude_idx = {76, 3, 384, 10}

# create a new dataset excluding those indices
# (the set is built once; calling set() inside the loop would rebuild it per row)
dataset = dataset.select(
    [i for i in range(len(dataset)) if i not in exclude_idx]
)