Datasets: Limit the number of rows?

I’m trying to make sure a script I’m hacking on works end-to-end, and waiting for training epochs to finish eats up a lot of time. I’ve already cut the number of epochs and the batch size down to 1, but I’m guessing the dataset I’m using is just too large, so getting through the batches still takes a long time.

I’m using some code from the GLUE example and it does the following:

dataset = datasets.load_dataset("glue", task)

I’d like to have it only take a set number of samples so I can iterate quicker.

Some things I’ve tried:

dataset = datasets.load_dataset("glue", g_task, split=split)
dataset = dataset[:20]

This complains with:

KeyError: "Invalid key: slice(None, 20, None). Please first select a split. For example: `my_dataset_dictionary['train'][slice(None, 20, None)]`. Available splits: ['test', 'train', 'validation']"

Fair, so it’s a dictionary. I then try this:

	dataset = datasets.load_dataset("glue", g_task, split=split)
	for k, v in dataset.items():
		dataset[k] = v[:20]

But then things blow up further on, because the indexing no longer behaves the way the rest of the script expects:

Traceback (most recent call last):
  File "D:\dev\valve\source2\main\src\devtools\k8s\dota\toxic-chat-ml\test\run_text_classification.py", line 157, in <module>
    train_and_save()
  File "D:\dev\valve\source2\main\src\devtools\k8s\dota\toxic-chat-ml\test\run_text_classification.py", line 55, in train_and_save
    print(f"Sentence: {dataset['train'][0][sentence1_key]}")
KeyError: 0

Which makes sense - comparing the two versions of the dict before and after the slice, it looks like I’m stomping each split’s Dataset (and its metadata) and replacing it with just the raw sliced output.
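For anyone following along, here’s a minimal sketch of what I mean (the "mrpc" task is just a placeholder for illustration): slicing a Dataset hands back a plain dict of column lists rather than another Dataset.

import datasets

dd = datasets.load_dataset("glue", "mrpc")
print(type(dd["train"]))       # <class 'datasets.arrow_dataset.Dataset'>
print(type(dd["train"][:20]))  # plain dict: column name -> list of values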

I then see that there’s a "split" argument I can pass to load_dataset that accepts a slice, but it seems to only work if I request a specific split (train/test) and won’t play nice with this dictionary-based approach.
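For reference, this is roughly what that split-slicing syntax looks like (a minimal sketch, with "mrpc" again standing in as a placeholder task); each requested slice comes back as its own Dataset rather than a DatasetDict, which is why it doesn’t slot straight into the dictionary-based code above:

import datasets

train_small, valid_small, test_small = datasets.load_dataset(
    "glue", "mrpc",
    split=["train[:20]", "validation[:20]", "test[:20]"],
)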

Am I missing an alternative here?

Thanks.

-e-

In case others can use the info, I worked around this by switching to my own dataset and just cutting down the amount of data in it. Crude, but effective.

If there is a way to do this in code, however, that would be appreciated. It’d be nice not to have to muck with my inputs when I just want to debug something.

Hi Eddie

The datasets documentation shows how to limit a dataset in various ways: see the Load guide.

Hope that helps 🙂

Cheers
Heiko

In case it helps others, this is the solution I went with:

		split='train'
		if is_dev_run():
			split = 'train[0:100]' # reduce the working size to speed up iteration

		dataset = datasets.load_dataset(input_files_directory, task, split=split)

		# 90% train, 10% test + validation
		train_testvalid = dataset.train_test_split(test_size=0.1)

		# Split the 10% test + valid in half test, half valid
		test_valid = train_testvalid['test'].train_test_split(test_size=0.5)

		# gather everything back into a single DatasetDict
		dataset = datasets.DatasetDict({
			'train': train_testvalid['train'],
			'test': test_valid['test'],
			'validation': test_valid['train'],
		})
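If the split needs to be the same between debug runs, train_test_split also accepts a seed (the value below is just an example):

		# optional: pin the shuffle so each debug run produces the same split
		train_testvalid = dataset.train_test_split(test_size=0.1, seed=42)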

You can use dataset.select(range(10)).
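In case a fuller sketch helps, select can be applied to every split of the DatasetDict (the "mrpc" task and the count of 20 are placeholders); unlike Python slicing, select returns a Dataset, so integer row indexing keeps working afterwards:

import datasets

dd = datasets.load_dataset("glue", "mrpc")
dd = datasets.DatasetDict({name: ds.select(range(20)) for name, ds in dd.items()})
print(dd["train"][0])  # works: each split is still a Dataset, so row indexing behaves as before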