Slow DataLoader with big batch_size

Hi,

I’m investigating the `datasets` library for loading tabular data with PyTorch. I’m having trouble getting good performance with both map-style and iterable-style datasets:

from datasets import Dataset, load_from_disk
import time
from torch.utils.data import DataLoader

if __name__ == '__main__':
    DATASET_SIZE = 10_000_000
    BATCH_SIZE = 10_000
    MAX_NUM_ITER = 10

    # Build a dummy tabular dataset and reload it memory-mapped from disk
    ds = Dataset.from_dict({"idx": range(DATASET_SIZE)})
    ds.save_to_disk('test.hf')

    print('Loading dataset')
    ds = load_from_disk('test.hf')

    # map-style dataset
    print('Running map style')
    data_loader = DataLoader(ds.with_format('torch'), batch_size=BATCH_SIZE)
    start = time.time()
    for i, batch in enumerate(data_loader):
        if i >= MAX_NUM_ITER:
            break
    print((time.time() - start) / MAX_NUM_ITER)

    # iterable-style dataset
    print('Running iter style')
    iter_ds = ds.to_iterable_dataset()
    data_loader = DataLoader(iter_ds, batch_size=BATCH_SIZE)
    start = time.time()
    for i, batch in enumerate(data_loader):
        if i >= MAX_NUM_ITER:
            break
    print((time.time() - start) / MAX_NUM_ITER)

This results in:

  • 0.146 s per iteration for the map-style dataset, and
  • 0.686 s per iteration for the iterable-style dataset.

If I understand the docs correctly (see Differences between Dataset and IterableDataset), I should use an IterableDataset with torch so that I can shuffle without running into speed issues:

However as soon as your Dataset has an indices mapping (via Dataset.shuffle() for example), the speed can become 10x slower.

If you want to shuffle your dataset or use it with a PyTorch DataLoader, we recommend generating a sharded IterableDataset:

Is there any way to make data loading faster? I get much better results by just using a NumPy memory-mapped file and loading the batches directly, i.e. something like

arr[start_idx:start_idx+BATCH_SIZE]

but the IterableDataset should at least be doing essentially the same thing, since it is just reading consecutive values from a memory-mapped file.
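To be concrete, this is the kind of baseline I mean (the file name and dtype are placeholders, not my real data):

```python
import numpy as np
import torch

DATASET_SIZE = 1_000_000
BATCH_SIZE = 10_000

# Write the array to disk once (stand-in for the real tabular data)
mm = np.memmap("test.bin", dtype=np.int64, mode="w+", shape=(DATASET_SIZE,))
mm[:] = np.arange(DATASET_SIZE)
mm.flush()

# Reopen read-only; each slice is a view into the memory-mapped file,
# and .copy() gives a writable array that torch can wrap.
arr = np.memmap("test.bin", dtype=np.int64, mode="r", shape=(DATASET_SIZE,))
for start_idx in range(0, DATASET_SIZE, BATCH_SIZE):
    batch = torch.from_numpy(arr[start_idx:start_idx + BATCH_SIZE].copy())
```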

Thanks, Lukas