Initializing splits from existing Dataset objects


I have a dataset that looks like this:

It consists of a single “train” split containing 2 columns, “image” and “label”. Each image has one of 4 possible labels: trainA, trainB, testA, testB.

What I’d like to do is turn this dataset into one that has 2 splits (“train” and “test”), each having 2 columns (“imageA” and “imageB”). How can I do this?

I can filter on them:

from datasets import load_dataset

dataset = load_dataset("huggan/grumpifycat")

# Label ids as implied by the variable names: 0 = testA, 1 = testB, 2 = trainA, 3 = trainB
train_examples_A = dataset.filter(lambda example: example['label'] == 2)['train']
train_examples_B = dataset.filter(lambda example: example['label'] == 3)['train']
test_examples_A = dataset.filter(lambda example: example['label'] == 0)['train']
test_examples_B = dataset.filter(lambda example: example['label'] == 1)['train']

Now I have 4 Dataset objects. Can I use those to instantiate a new DatasetDict with these objects?

Hi! A DatasetDict is simply a Python dictionary of datasets:

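from datasets import DatasetDict, concatenate_datasets

# axis=1 concatenates column-wise, so row i of A is paired with row i of B
train_dataset = concatenate_datasets([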
    train_examples_A.rename_column("image", "imageA").remove_columns(["label"]),
    train_examples_B.rename_column("image", "imageB").remove_columns(["label"])
], axis=1)
test_dataset = concatenate_datasets([
    test_examples_A.rename_column("image", "imageA").remove_columns(["label"]),
    test_examples_B.rename_column("image", "imageB").remove_columns(["label"])
], axis=1)

datasets = DatasetDict({
    "train": train_dataset,
    "test": test_dataset,
})

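Concatenating with axis=1 requires both datasets to have the same number of rows, so you might need to make sure that train_examples_A and train_examples_B (and likewise test_examples_A and test_examples_B) have the same length before concatenating them.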