num_workers with IterableDataset

For example, I have a dataset with 500 labels, and each label has 1-20 GB of raw image data. For training, I want each batch to randomly select batch_size labels and, for each selected label, randomly sample 256 images of that label.

Since the whole dataset is huge, I created a Hugging Face Dataset file for each label, e.g.,
./dataset/label1/data-00000-of-00001.arrow, ./dataset/label2/data-00000-of-00001.arrow ...

I tried two methods to load the dataset to see if I can speed up:

# (inside the dataset's sampling method; `self.labels`, `self.dataset_path`,
#  `self.num_sample`, and `get_arrow_files` are defined elsewhere)
gid = random.sample(self.labels, 1)[0]
files = get_arrow_files(os.path.join(self.dataset_path, gid))
data_files = {"train": files}

# method 1: streaming
# ds = load_dataset("arrow", split='train', data_files=data_files, streaming=True)
# method 2: memory-mapped Dataset
ds = Dataset.from_file(files[0])
ds = ds.shuffle()

# collect num_sample examples of the chosen label
datas = []
for idx, batch in enumerate(ds):
    datas.append(batch)
    if (idx + 1) == self.num_sample:
        break
return np.asarray(datas)

For the first one, load_dataset("arrow", split='train', data_files=data_files, streaming=True), it loads a single batch faster than ds = Dataset.from_file(files[0]), but it can’t use multiple dataloader workers to load batches in parallel, since it prints this warning:

WARNING:datasets.iterable_dataset:Too many dataloader workers: 2 (max is dataset.n_shards=1). Stopping 1 dataloader workers.

For the second one, ds = Dataset.from_file(files[0]), multiple workers can load different Arrow files simultaneously, but it is pretty slow when an Arrow file is large (around 20 GB), and the queue stalls until the first worker finishes loading.

I’m wondering if there’s any way to speed up the data loading process without changing the data sampling strategy. Thanks!

Iterable datasets are generally faster since they do contiguous reads of the data. And sharded iterable datasets can be read faster using multiple dataloader workers.

If you have N Arrow files that you pass as data_files to load_dataset(..., streaming=True) then you should end up with a sharded iterable dataset:

ds = load_dataset("arrow", split='train', data_files=data_files, streaming=True)
assert ds.n_shards == len(data_files["train"])

As soon as you have multiple shards, you can use multiple dataloader workers (up to ds.n_shards) to load data faster in parallel.
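
For example, here is a minimal sketch of that setup (the glob pattern assumes the per-label layout from your post, and the batch size and worker count are just placeholders, not recommendations):

from glob import glob

from datasets import load_dataset
from torch.utils.data import DataLoader

# one Arrow file per label -> one shard per file
data_files = {"train": sorted(glob("./dataset/*/data-*.arrow"))}
ds = load_dataset("arrow", split="train", data_files=data_files, streaming=True)

# up to ds.n_shards dataloader workers can read shards in parallel
dl = DataLoader(ds, batch_size=256, num_workers=min(8, ds.n_shards))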

You can also get a sharded iterable dataset from a single file, but in this case you need to load it as a Dataset first:

ds = Dataset.from_file(file)
ds = ds.to_iterable_dataset(num_shards=8)
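
As a rough sketch of how this combines with a DataLoader (the file path reuses the example layout above; the shard count, shuffle buffer, and batch size are illustrative):

from datasets import Dataset
from torch.utils.data import DataLoader

ds = Dataset.from_file("./dataset/label1/data-00000-of-00001.arrow")
ds = ds.to_iterable_dataset(num_shards=8)   # split the single file into 8 shards
ds = ds.shuffle(seed=42, buffer_size=1000)  # optional approximate shuffling
# each worker consumes its own shard(s); their outputs are interleaved
dl = DataLoader(ds, batch_size=256, num_workers=8)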

If I have a sharded iterable dataset ds, and then use it with, say,
DataLoader(ds, num_workers=8, ...)
(assuming n_shards > 8), this will result in 8 workers reading from and processing 8 different shards of ds, right?

I’m a bit confused here, because I thought when pulling items from a sharded iterable dataset, the shards are exhausted in serial. That is, one shard is completely consumed, before moving on to the next. So I don’t get how having multiple workers that handle different shards will speed things up in general.

If the shards are consumed in a round robin fashion, that would be another story. But I don’t think that is the case? I am definitely misunderstanding something here =(

In this case the shards are loaded in parallel by the workers (one shard per worker), and their contents are interleaved when you iterate over the data loader. That’s what speeds up data loading 🙂


Thanks for clarifying! And apologies - should’ve verified this myself:

from datasets import load_dataset
from torch.utils.data import DataLoader

with open('./a.txt', 'w') as f:
    f.writelines(['a\n' for _ in range(20)])
with open('./b.txt', 'w') as f:
    f.writelines(['b\n' for _ in range(20)])

ds = load_dataset(
    'text', data_files=['./a.txt', './b.txt'], streaming=True, split=None
)['train']
assert ds.n_shards == 2

dl = DataLoader(ds, num_workers=2)
for item in dl:
    print(item)

Output (items from the two shards are interleaved):

a
b
a
b…