I tried your approach, but I end up with empty batches. Creating a DataLoader for the whole dataset works:
dataloaders = {"train": DataLoader(dataset, batch_size=8)}
for batch in dataloaders["train"]:
    print(batch.keys())
# prints the expected keys
But when I split the dataset as you suggest, I run into issues; the batches are empty.
from datasets import DatasetDict
from torch.utils.data import DataLoader

# dataset is already `map`'d and already has `set_format`
# 90% train, 10% test + validation
train_testvalid = dataset.train_test_split(test_size=0.1)
# Split the 10% test + valid in half test, half valid
test_valid = train_testvalid["test"].train_test_split(test_size=0.5)
# Gather everything into a single DatasetDict
datasets = DatasetDict({
    "train": train_testvalid["train"],
    "test": test_valid["test"],
    "valid": test_valid["train"],
})
dataloaders = {partition: DataLoader(ds, batch_size=8) for partition, ds in datasets.items()}
for batch in dataloaders["train"]:
    print(batch.keys())
# dict_keys([])
Even when I only split into train and test, the batches are empty.
train_test = dataset.train_test_split(test_size=0.1)
print(next(iter(train_test["train"]))) # empty
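A quick way to check whether the split dropped the formatting is to inspect the `format` property on the original dataset and on each split. A minimal sketch, using a dummy dataset in place of the real mapped one:

```python
from datasets import Dataset

# Dummy stand-in for the real mapped dataset
ds = Dataset.from_dict({"input_ids": [[1, 2], [3, 4], [5, 6], [7, 8]]})
ds.set_format(type="torch", columns=["input_ids"])
print(ds.format["columns"])  # formatted columns before splitting

splits = ds.train_test_split(test_size=0.5)
print(splits["train"].format["columns"])  # formatted columns after splitting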
While running those last two snippets, I see a lot of warnings/log messages. With a dummy dataset of 100 entries, the output looks like this:
PyTorch version 1.6.0+cu101 available.
Testing the mapped function outputs
Testing finished, running the mapping function on the dataset
100%|██████████| 100/100 [00:00<00:00, 4999.95ex/s]
Done writing 100 examples in 426851 bytes .
Testing the mapped function outputs
Testing finished, running the mapping function on the dataset
100%|██████████| 1/1 [00:00<00:00, 15.38ba/s]
Done writing 100 examples in 1656863 bytes .
Set __getitem__(key) output type to torch for ['input_ids', 'sembedding'] columns (when key is int or slice) and don't output other (un-formatted) columns.
Done writing 90 indices in 720 bytes .
Set __getitem__(key) output type to torch for ['input_ids', 'sembedding'] columns (when key is int or slice) and don't output other (un-formatted) columns.
Done writing 10 indices in 80 bytes .
Set __getitem__(key) output type to torch for ['input_ids', 'sembedding'] columns (when key is int or slice) and don't output other (un-formatted) columns.
Set __getitem__(key) output type to torch for [] columns (when key is int or slice) and don't output other (un-formatted) columns.
Set __getitem__(key) output type to torch for [] columns (when key is int or slice) and don't output other (un-formatted) columns.
It seems that set_format is called again for each split, but after that it is called once more with empty column names, meaning that no columns are included in the end. I am not sure why it is being called with empty columns there. I thought this might be a bug in the fingerprinting, but the issue persists even after clearing the cache.
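In the meantime, one possible workaround is to re-apply `set_format` to each split after splitting, under the assumption that the split is what resets the formatted columns. A sketch with a dummy dataset (not the real mapped one):

```python
from datasets import Dataset

# Dummy stand-in for the real mapped dataset
ds = Dataset.from_dict({"input_ids": [[1, 2], [3, 4], [5, 6], [7, 8]]})
ds.set_format(type="torch", columns=["input_ids"])

splits = ds.train_test_split(test_size=0.5)
# Workaround: re-apply the format to every split
for split in splits.values():
    split.set_format(type="torch", columns=["input_ids"])

print(splits["train"][0].keys())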