load_dataset using arrow data_files + streaming raises KeyError: 0 with the PyTorch DataLoader

I wrote the following function to load the splits of a large (1 TB) dataset, using datasets version 2.17.0:

import os

from datasets import load_dataset

# dataset_path is defined elsewhere and points at the dataset root
def load_embeddings(split):
    split_path = os.path.join(dataset_path, split)
    # Collect every .arrow shard in this split's directory
    arrow_files = [os.path.join(split_path, f) for f in os.listdir(split_path) if f.endswith('.arrow')]

    if not arrow_files:
        raise ValueError(f"No .arrow files found in the specified split path: {split_path}")

    # Load the dataset using the arrow files collected for the specified split
    dataset = load_dataset("arrow", data_files=arrow_files, streaming=True)

    return dataset

When I use a DataLoader like so:

# Usage example for train, validation, and test splits
print("Started Reading Dataset")
train_dataset = load_embeddings("train")
test_dataset = load_embeddings("test")
val_dataset = load_embeddings("validation")
print("Dataset Read")

collator = DataCollatorForMyModel()

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, collate_fn=collator)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, collate_fn=collator)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, collate_fn=collator)

During training, the following error is raised:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[13], line 37
     35 # Call the training function
     36 print("Training Transformer Bridge")
---> 37 train_model(model, train_loader, optimizer, scheduler, input_model, output_model, writer, val_loader, test_loader)
     39 #Test Bridge
     40 print("Testing Transformer Bridge")

Cell In[8], line 73
     71 checkpoint_count = 1
     72 while step_count <= target_total_steps:
---> 73     for batch in train_loader:
     74         if step_count >= target_total_steps:
     75             break  # Break the loop if target steps reached

File ~redacted/temp/9701250/conda_environment/lib/python3.10/site-packages/torch/utils/data/dataloader.py:630, in _BaseDataLoaderIter.__next__(self)
    627 if self._sampler_iter is None:
    628     # TODO(https://github.com/pytorch/pytorch/issues/76750)
    629     self._reset()  # type: ignore[call-arg]
--> 630 data = self._next_data()
    631 self._num_yielded += 1
    632 if self._dataset_kind == _DatasetKind.Iterable and \
    633         self._IterableDataset_len_called is not None and \
    634         self._num_yielded > self._IterableDataset_len_called:

File ~redacted/temp/9701250/conda_environment/lib/python3.10/site-packages/torch/utils/data/dataloader.py:674, in _SingleProcessDataLoaderIter._next_data(self)
    672 def _next_data(self):
    673     index = self._next_index()  # may raise StopIteration
--> 674     data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    675     if self._pin_memory:
    676         data = _utils.pin_memory.pin_memory(data, self._pin_memory_device)

File ~redacted/temp/9701250/conda_environment/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:51, in _MapDatasetFetcher.fetch(self, possibly_batched_index)
     49         data = self.dataset.__getitems__(possibly_batched_index)
     50     else:
---> 51         data = [self.dataset[idx] for idx in possibly_batched_index]
     52 else:
     53     data = self.dataset[possibly_batched_index]

File ~redacted/temp/9701250/conda_environment/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py:51, in <listcomp>(.0)
     49         data = self.dataset.__getitems__(possibly_batched_index)
     50     else:
---> 51         data = [self.dataset[idx] for idx in possibly_batched_index]
     52 else:
     53     data = self.dataset[possibly_batched_index]

KeyError: 0

This error does not happen when I use a more traditional streaming approach like this:

load_dataset(path).to_iterable_dataset(num_shards=64)

I am not sure why "for batch in train_loader" tries to access a specific index in this case.

Loading the dataset without streaming isn’t really an option, as it takes 4+ hours just to load.

After more testing, I found that this only happens when the dataset is an IterableDatasetDict, not when it is an IterableDataset.
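A likely explanation, as far as I can tell: an IterableDatasetDict is a dict subclass keyed by split name, not a torch.utils.data.IterableDataset, so DataLoader treats it as a map-style dataset. The default SequentialSampler then yields indices from range(len(dataset)), which is just [0] for a single split, and the fetcher calls dataset[0]. That fails with KeyError: 0 because the keys are strings like "train". A datasets IterableDataset, by contrast, is recognised by PyTorch as iterable-style (datasets adds torch's IterableDataset as a parent class when torch is available), so no indexing happens. The stdlib-only sketch below mimics that map-style path; map_style_fetch and the toy splits dict are hypothetical stand-ins, not PyTorch code.

```python
# Stdlib-only mimic of how torch's DataLoader iterates a map-style dataset.
# Assumption: this mirrors SequentialSampler (indices from range(len(dataset)))
# plus _MapDatasetFetcher ([dataset[idx] for idx in batch]); it is not torch itself.

def map_style_fetch(dataset, batch_size):
    """Yield batches the way a map-style DataLoader would, without shuffling."""
    indices = list(range(len(dataset)))  # what SequentialSampler produces
    for start in range(0, len(indices), batch_size):
        batch_indices = indices[start:start + batch_size]
        # Same list comprehension as fetch.py line 51 in the traceback
        yield [dataset[idx] for idx in batch_indices]


# Hypothetical stand-in for an IterableDatasetDict: a dict keyed by split name.
splits = {"train": iter([{"x": 1}, {"x": 2}])}

caught = None
try:
    next(map_style_fetch(splits, batch_size=2))
except KeyError as err:
    caught = err  # len(splits) == 1, so the only index tried is 0

print(f"KeyError: {caught}")  # prints "KeyError: 0"
```

The same loop works fine on a real sequence such as a list, which is why the problem only shows up when a dict of splits is passed in place of a single split.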

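If anyone finds this and is looking for a fix: even if you load the arrow files for an individual split directly, as I did, the output will be an IterableDatasetDict (a DatasetDict without streaming) with a single split named ‘train’. You need to select that ‘train’ split, either with dataset["train"] or by passing split="train" to load_dataset, for it to work.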
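The most direct change is to call load_dataset("arrow", data_files=arrow_files, split="train", streaming=True) inside load_embeddings, which returns an IterableDataset rather than a dict of splits. If you would rather unwrap the result defensively, a minimal sketch is below; select_split is a hypothetical helper, and it relies on DatasetDict and IterableDatasetDict both subclassing dict, so a Mapping check catches them while a single dataset passes through unchanged.

```python
from collections.abc import Mapping


def select_split(dataset, split="train"):
    """Return dataset[split] if `dataset` is a dict of splits, else `dataset` itself."""
    # DatasetDict / IterableDatasetDict subclass dict, so they count as a Mapping;
    # a single Dataset / IterableDataset is not a Mapping and is returned as-is.
    if isinstance(dataset, Mapping):
        if split not in dataset:
            raise KeyError(f"split {split!r} not found; available: {list(dataset)}")
        return dataset[split]
    return dataset


# Hypothetical usage with plain Python stand-ins
print(select_split({"train": [1, 2, 3]}))  # prints [1, 2, 3]
print(select_split([4, 5]))                # prints [4, 5]
```

In load_embeddings this would be "return select_split(dataset)", after which the DataLoader receives an IterableDataset and iterates it instead of indexing it.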