I have created the following function to load splits of a large 1TB dataset (using Datasets version: 2.17.0):
def load_embeddings(split):
split_path = os.path.join(dataset_path, split)
arrow_files = [os.path.join(split_path, f) for f in os.listdir(split_path) if f.endswith('.arrow')]
if not arrow_files:
raise ValueError(f"No .arrow files found in the specified split path: {split_path}")
# Load the dataset using the arrow files collected for the specified split
dataset = load_dataset("arrow", data_files=arrow_files, streaming=True)
return dataset
When I use a DataLoader like so:
# Usage example for train, validation, and test splits
print("Started Reading Dataset")
train_dataset = load_embeddings("train")
test_dataset = load_embeddings("test")
val_dataset = load_embeddings("validation")
print("Dataset Read")
collator = DataCollatorForMyModel()
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, collate_fn=collator)
test_loader = val_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, collate_fn=collator)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, collate_fn=collator)
During training the following error happens:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
Cell In[13], line 37
35 # Call the training function
36 print("Training Transformer Bridge")
---> 37 train_model(model, train_loader, optimizer, scheduler, input_model, output_model, writer, val_loader, test_loader)
39 #Test Bridge
40 print("Testing Transformer Bridge")
Cell In[8], line 73
71 checkpoint_count = 1
72 while step_count <= target_total_steps:
---> 73 for batch in train_loader:
74 if step_count >= target_total_steps:
75 break # Break the loop if target steps reached
File~redacted/temp/9701250/conda_environment/lib/python3.10/site-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
627 if self._sampler_iter is None:
628 # TODO(https://github.com/pytorch/pytorch/issues/76750)
629 self._reset() # type: ignore[call-arg]
--> 630 data = self._next_data()
631 self._num_yielded += 1
632 if self._dataset_kind == _DatasetKind.Iterable and \
633 self._IterableDataset_len_called is not None and \
634 self._num_yielded > self._IterableDataset_len_called:
File~redacted/temp/9701250/conda_environment/lib/python3.10/site-packages/torch/utils/data/dataloader.py:674, in _SingleProcessDataLoaderIter._next_data(self)
672 def _next_data(self):
673 index = self._next_index() # may raise StopIteration
--> 674 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
675 if self._pin_memory:
676 data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)
File ~redacted/temp/9701250/conda_environment/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:51, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
49 data = self.dataset.__getitems__(possibly_batched_index)
50 else:
---> 51 data = [self.dataset[idx] for idx in possibly_batched_index]
52 else:
53 data = self.dataset[possibly_batched_index]
File ~redacted/temp/9701250/conda_environment/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:51, in <listcomp>(.0)
49 data = self.dataset.__getitems__(possibly_batched_index)
50 else:
---> 51 data = [self.dataset[idx] for idx in possibly_batched_index]
52 else:
53 data = self.dataset[possibly_batched_index]
KeyError: 0
This error is not something to happens when I use a more traditional streaming dataset loading approach like so:
load_dataset(path).to_iterable_dataset(num_shards=64)
I am not sure why for batch in train_loader
attempts to access a specific index in this case.
Loading the dataset without streaming isnāt really an option as it takes 4hrs+ to just load.