Confusion in splitting dataset (from imagefolder) into train, test and validation


I am trying to load images from a dataset in order to fine-tune a vision transformer model. My dataset has the following structure:

----ClassA (x images)
----ClassB (y images)
----ClassC (z images)

I am quite confused about how to split the dataset into train, test and validation sets. I read various similar questions but couldn’t understand the process clearly. I have tried the following (a non-random train-test-validation split, though I would actually like the split to be random):

from datasets import load_dataset

ds = load_dataset("imagefolder",data_dir="/Documents/DataSetFolder/",split="test")
# split up data into train + test

splits = ds.train_test_split(test_size=0.3)

train_ds = splits['train']

test_ds = splits['test']
# split up data into val + test

splits = ds.train_test_split(test_size=0.15)

test_ds = splits['test']

val_ds = splits['test']

Is this a correct process for randomly splitting the dataset into 70% training, 15% test and 15% validation? Also, how can I do a random train-test split, and what is the significance of the split argument in load_dataset? I would really appreciate some guidance on this, as I am still very confused even after reading the documentation.

Thanks very much!

Hi! Your code produces (potentially overlapping) splits of incorrect size. This is the fixed code:

from datasets import load_dataset

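# With only class folders under data_dir (no train/test subfolders), imagefolder puts all images in the "train" split
ds = load_dataset("imagefolder", data_dir="/Documents/DataSetFolder/", split="train")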

ds_split_train_test = ds.train_test_split(test_size=0.15)

train_ds, test_ds = ds_split_train_test["train"], ds_split_train_test["test"]

# 0.15/0.85 of the remaining 85% equals 15% of the full dataset
ds_split_train_val = train_ds.train_test_split(test_size=0.15/0.85)

train_ds, val_ds = ds_split_train_val["train"], ds_split_train_val["test"]
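To see why the second call uses a test size of 0.15/0.85 rather than 0.15, here is a minimal pure-Python sketch of the same two-stage idea on plain row indices. The helper name two_stage_split, the seed and the item count are made up for illustration; it mimics the arithmetic only and is not how train_test_split is implemented internally.

```python
import random


def two_stage_split(n_items, test_frac=0.15, val_frac=0.15, seed=0):
    """Shuffle indices once, then carve off test and validation in two steps."""
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)

    # Step 1: hold out the test set from the full data.
    n_test = round(n_items * test_frac)
    test, rest = indices[:n_test], indices[n_test:]

    # Step 2: the validation fraction is relative to what is left (85%),
    # so 15% of the total becomes 0.15 / 0.85 of the remainder.
    val_frac_of_rest = val_frac / (1.0 - test_frac)
    n_val = round(len(rest) * val_frac_of_rest)
    val, train = rest[:n_val], rest[n_val:]
    return train, val, test


train, val, test = two_stage_split(1000)
print(len(train), len(val), len(test))  # 700 150 150
```

Because each stage only takes rows from what the previous stage left over, the three sets cannot overlap. Splitting the full dataset twice, as in the question, gives no such guarantee.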

what is the significance of the split argument in load_dataset?

If specified, load_dataset returns that one split as a dataset instead of returning a dictionary with all the splits. You can think of it as being equal to load_dataset("imagefolder", data_dir="/Documents/DataSetFolder/")[split]. Note that this argument also supports the slicing syntax, but you shouldn’t use it here because slicing takes rows in their stored order without shuffling them. train_test_split, by contrast, shuffles by default, which is what makes the splits above random.
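As a rough illustration of why an unshuffled slice is risky here, the sketch below uses a plain Python list in place of the dataset. It assumes the rows are stored class by class, which is an assumption about file ordering and not something guaranteed for your data; the class counts are invented.

```python
import random

# Hypothetical labels in stored order: rows are grouped class by class.
labels = ["ClassA"] * 50 + ["ClassB"] * 30 + ["ClassC"] * 20

# Unshuffled slice: the last 15% of rows, similar in spirit to split="train[85%:]".
cut = len(labels) * 85 // 100
sliced_test = labels[cut:]

# Shuffled split: permute the row indices first, then take the same number of rows.
order = list(range(len(labels)))
random.Random(0).shuffle(order)
shuffled_test = [labels[i] for i in order[cut:]]

print(set(sliced_test))            # only the last class appears
print(sorted(set(shuffled_test)))  # classes drawn from the shuffled order
```

Under that ordering assumption the sliced test set contains a single class, while the shuffled one samples rows from across the whole dataset.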


@mariosasko I have a slight variation of the question here. Maybe you can help me?