How to split main dataset into train, dev, test as DatasetDict

It seems that a single dataset can be split up into different partitions while keeping the connection between them clear (by using a DatasetDict), which is neat. I am having difficulty figuring out how to create and use them, though. I’ve been going through the documentation [1],[2] and the source code [1],[2], but it hasn’t become any clearer.

In some places the docs speak of only a train/test split; in other places they include validation. It is not clear how to split an existing Dataset into a DatasetDict containing train, dev (validation), and test keys, even though I think this is a common scenario.

The focus of the library seems to be on the use of existing datasets (which is a great feature!), but time permitting I would like to see more information and use cases for creating and managing your own datasets. I’m usually pretty good at figuring out a code base from its documentation and source code (e.g. transformers), but datasets seems to give me a harder time. That is not to say that I think it is bad, of course. But it does show, I think, that even though it seems simple from the examples, it is actually quite complex behind the scenes. For people who want to do custom stuff, it is then super important to have clear, verbose documentation about as many aspects of the library as possible.


Hi Bram,
Yes, the documentation of train_test_split that you link to is the right one. The train_test_split method currently provided is essentially a copy of the well-known sklearn train_test_split (which we kind of assume people are familiar with); we just removed the stratified split options, which are quite complex.
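To make the sklearn analogy concrete, here is a minimal sketch of how a float test_size turns into split sizes. It assumes the sklearn rounding convention (the test side is rounded up with ceil and the train side gets the remainder) carries over to datasets unchanged; split_sizes is a hypothetical helper for illustration, not a datasets function.

```python
from math import ceil
from typing import Tuple


def split_sizes(n_samples: int, test_size: float) -> Tuple[int, int]:
    """Return (n_train, n_test) for a float test_size.

    Hypothetical helper: assumes sklearn's rounding, where the test side is
    rounded up and the train side receives the remainder.
    """
    if not 0.0 < test_size < 1.0:
        raise ValueError("test_size must be a float strictly between 0 and 1")
    n_test = ceil(test_size * n_samples)  # test side rounded up
    n_train = n_samples - n_test          # train side takes what is left
    if n_train <= 0:
        raise ValueError("this split would leave the train set empty")
    return n_train, n_test


print(split_sizes(100, 0.1))  # (90, 10)
print(split_sizes(10, 0.5))   # (5, 5)
print(split_sizes(7, 0.5))    # (3, 4): the odd example goes to the test side
```

For a 100-example dataset, a test_size of 0.1 leaves 10 held-out examples, and halving those 10 again with a test_size of 0.5 gives 5 and 5.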
We could indeed add an option to split into three parts with a validation split; feel free to open a PR if you would like to have this feature soon.
Right now, what you can do is split twice:

from datasets import DatasetDict

# 90% train, 10% test + validation
train_testvalid = dataset.train_test_split(test_size=0.1)
# Split the 10% test + valid in half test, half valid
test_valid = train_testvalid['test'].train_test_split(test_size=0.5)
# gather everyone if you want to have a single DatasetDict
train_test_valid_dataset = DatasetDict({
    'train': train_testvalid['train'],
    'test': test_valid['test'],
    'valid': test_valid['train']})
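The same two-step idea can be sketched without any library on plain row indices. two_step_split below is a hypothetical helper, not part of datasets; it assumes a seeded shuffle and sklearn-style rounding (the held-out side is rounded up). The returned index lists could then be passed to Dataset.select to materialize each split.

```python
import random
from math import ceil
from typing import Dict, Iterable, List


def two_step_split(
    indices: Iterable[int], holdout: float = 0.1, seed: int = 42
) -> Dict[str, List[int]]:
    """Hypothetical helper: split indices into train/test/valid by splitting twice."""
    shuffled = list(indices)
    random.Random(seed).shuffle(shuffled)  # seeded so the split is reproducible

    # Step 1: hold out `holdout` of the data (rounded up) for test + validation.
    n_holdout = ceil(holdout * len(shuffled))
    rest, train = shuffled[:n_holdout], shuffled[n_holdout:]

    # Step 2: split the held-out part in half; the test half is rounded up.
    n_test = ceil(0.5 * len(rest))
    test, valid = rest[:n_test], rest[n_test:]

    return {"train": train, "test": test, "valid": valid}


splits = two_step_split(range(100))
print({name: len(idx) for name, idx in splits.items()})  # {'train': 90, 'test': 5, 'valid': 5}
```

Because each index lands in exactly one list, the three splits are disjoint and together cover the whole dataset.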

The mention of a validation split that you point to is just an enum provided for dataset creators who would like to include a standard name for a validation split in their dataset.

We will try to add more documentation on the code organization, but bear in mind that (1) this library is still very young and (2) far fewer people are working on it (it’s really mostly one person, Quentin, whom I try to help as much as I can), so it will definitely take some time before we “have clear, verbose documentation about as many aspects of the library as possible”.

Basically, to give you an idea, the code is organized in two main parts:

  1. the dataset building part, which is defined in part by the people writing datasets and is very open (hence the many options for splits in this part) => this is most of the complex code because it is a wrapper around scripts provided externally. This includes files like builder.py, load.py, arrow_dataset.py.
  2. the dataset processing part (after the dataset has been built), which is mostly contained in the arrow_dataset.py file and contains most of what users will actually interact with => this is probably the part you need to read the most. The main complexity here is that we are deeply integrated with Apache Arrow, which is very efficient but definitely not the easiest framework to understand.

You can also read the section of the docs where I tried to draw a diagram and give some information on how datasets are created (the first part in my list above): https://huggingface.co/docs/datasets/add_dataset.html


I tried your approach, but I end up with empty batches here. Creating a dataloader for the whole dataset works:

from torch.utils.data import DataLoader

dataloaders = {"train": DataLoader(dataset, batch_size=8)}

for batch in dataloaders["train"]:
    print(batch.keys())
    # prints the expected keys

But when I split the dataset as you suggest, I run into issues; the batches are empty.

from datasets import DatasetDict

# dataset is already `map`'d and already has `set_format`
# 90% train, 10% test + validation
train_testvalid = dataset.train_test_split(test_size=0.1)
# Split the 10% test + valid in half test, half valid
test_valid = train_testvalid["test"].train_test_split(test_size=0.5)
# gather everyone if you want to have a single DatasetDict
datasets = DatasetDict({
    "train": train_testvalid["train"],
    "test": test_valid["test"],
    "valid": test_valid["train"]})

dataloaders = {partition: DataLoader(ds, batch_size=8) for partition, ds in datasets.items()}

for batch in dataloaders["train"]:
    print(batch.keys())
    # dict_keys([])

Even when I just split into train and test, the batches are empty.

train_test = dataset.train_test_split(test_size=0.1)
print(next(iter(train_test["train"])))  # empty

While running the last two snippets, I see a lot of warnings/log messages. Using a dummy dataset of 100 entries, I see this:

PyTorch version 1.6.0+cu101 available.
Testing the mapped function outputs
Testing finished, running the mapping function on the dataset
100%|██████████| 100/100 [00:00<00:00, 4999.95ex/s]
Done writing 100 examples in 426851 bytes .
Testing the mapped function outputs
Testing finished, running the mapping function on the dataset
100%|██████████| 1/1 [00:00<00:00, 15.38ba/s]
Done writing 100 examples in 1656863 bytes .
Set __getitem__(key) output type to torch for ['input_ids', 'sembedding'] columns  (when key is int or slice) and don't output other (un-formatted) columns.
Done writing 90 indices in 720 bytes .
Set __getitem__(key) output type to torch for ['input_ids', 'sembedding'] columns  (when key is int or slice) and don't output other (un-formatted) columns.
Done writing 10 indices in 80 bytes .
Set __getitem__(key) output type to torch for ['input_ids', 'sembedding'] columns  (when key is int or slice) and don't output other (un-formatted) columns.
Set __getitem__(key) output type to torch for [] columns  (when key is int or slice) and don't output other (un-formatted) columns.
Set __getitem__(key) output type to torch for [] columns  (when key is int or slice) and don't output other (un-formatted) columns.

It seems that for each split, set_format is called again, but after that it is called once more with an empty list of column names, meaning that no columns will be included in the end. I am not sure why it is called with empty columns there. I thought that this might be a bug with the fingerprinting, but the problem still occurs after clearing the cache.

The solution seems to be to do the partitioning first and then call map and set_format on the DatasetDict. That being said, I still feel like the behaviour I mentioned before is a bug, but I am not sure.
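The log line "Set __getitem__(key) output type to torch for [] columns ... and don't output other (un-formatted) columns" appears to explain the empty batches: a format with an empty column list leaves nothing to return. Below is a toy, dependency-free model of that behaviour together with the split-first workaround. ToyFormattedDataset is hypothetical and only mimics what the log message describes; it is not how datasets implements formatting, and its select returning an unformatted copy is a simplifying assumption.

```python
from typing import Dict, List, Optional


class ToyFormattedDataset:
    """Hypothetical stand-in for a formatted dataset; not the datasets implementation."""

    def __init__(self, rows: List[Dict[str, object]]):
        self.rows = rows
        self.columns: Optional[List[str]] = None  # None = no format, return every column

    def set_format(self, columns: List[str]) -> None:
        self.columns = list(columns)

    def select(self, indices: List[int]) -> "ToyFormattedDataset":
        # Toy assumption: a split is an unformatted copy; the format is set per split afterwards.
        return ToyFormattedDataset([self.rows[i] for i in indices])

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, i: int) -> Dict[str, object]:
        row = self.rows[i]
        if self.columns is None:
            return dict(row)
        # As the log describes: only formatted columns are returned, the rest are dropped.
        return {k: row[k] for k in self.columns}


rows = [{"input_ids": [i], "sembedding": [0.0], "text": f"example {i}"} for i in range(10)]

# Symptom: an empty format column list means every row comes back empty.
broken = ToyFormattedDataset(rows)
broken.set_format([])
print(broken[0])  # {}

# Workaround from the post: partition first, then set the format on each split.
full = ToyFormattedDataset(rows)
splits = {"train": full.select(list(range(9))), "test": full.select([9])}
for split in splits.values():
    split.set_format(["input_ids", "sembedding"])
print(splits["train"][0])  # {'input_ids': [0], 'sembedding': [0.0]}
```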

Thanks for the feedback @BramVanroy

There are indeed a few things here, and I think the first is that we should rationalize the logging of the library.

For instance, these "Set __getitem__(key) output type to torch" logs are too much information for users, in particular since that formatting is only used internally by the library in these calls.

Now, the fact that you have to call the methods in a specific order does seem to be a bug; if you want to open an issue about it, we would be happy to investigate.