[Solved] Image dataset seems slow for larger image sizes

Hi, @lhoestq

(using datasets==1.16.1)

I can’t reproduce your super good results. Here is what I get

(the complete script, with your optimizations applied, is included below)

with your optimizations:

batch: 0 | time: 2.873563528060913
batch: 1 | time: 2.5054991245269775
batch: 2 | time: 2.809781312942505
batch: 3 | time: 2.65997314453125
Average times_1 in data_loader: 2.6919819116592407
Average times_2 in data_loader: 0.01998382806777954

without optimizations:

batch: 0 | time: 8.33098292350769
batch: 1 | time: 7.9121692180633545
batch: 2 | time: 8.10653042793274
batch: 3 | time: 7.925087928771973
Average times_1 in data_loader: 6.770546138286591
Average times_2 in data_loader: 1.2979284524917603
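
To be explicit about the comparison: "without optimizations" means the same script, but with the features=... argument dropped from the map call and without with_format("numpy"). A minimal sketch of that baseline (assuming I applied your suggestions correctly):

train_ds = train_ds.map(
    preprocess_function,
    remove_columns=["img", "label"],
    batched=True,
    batch_size=16,
    num_proc=2,
    # no features=features here, so the Array3D schema is not declared
)
# and no train_ds.with_format("numpy") call afterwards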

Whenever you have some time, could you help me reproduce your results, please?
Thank you!

The script:

import numpy as np
import time

import datasets
from datasets import load_dataset, Dataset


ds = load_dataset("cifar10")
train_ds = ds["train"]
# Take 512 examples
train_ds = train_ds.select(range(512))


# `_images` has shape (batch_size, 32, 32, 3)
# Let's make large images (batch_size, 32 * FACTOR, 32 * FACTOR, 3)
FACTOR = 4


def preprocess_function(examples):

    inputs = {}
    images = examples["img"]

    # convert to an np.int32 array
    _images = np.array(images, dtype=np.int32)

    # Make the images larger
    batch_size, image_size = _images.shape[0:2]

    image_size = image_size * FACTOR

    # use (H, W, C) format
    _images = np.concatenate([_images] * FACTOR ** 2, axis=0).reshape((batch_size, image_size, image_size, 3))
    # flatten images to 1D -> `data_loader` will be faster
    # _images = _images.reshape((batch_size, image_size * image_size * 3))

    inputs["image"] = _images

    return inputs


times_1 = []
times_2 = []

features = datasets.Features({
    "image": datasets.Array3D(dtype="int32", shape=(32 * FACTOR, 32 * FACTOR, 3))
})


def data_loader(rng: np.random.Generator, dataset: Dataset, batch_size: int, shuffle: bool = False):
    """
    Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
    Shuffle batches if `shuffle` is `True`.
    """

    steps_per_epoch = len(dataset) // batch_size

    if shuffle:
        batch_idx = rng.permutation(len(dataset))
    else:
        batch_idx = np.arange(len(dataset))

    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

    for idx in batch_idx:

        s = time.time()
        # (bs=128, img_size=128)
        #   - (H, W, C) -> takes 6.959 seconds
        #   - flatten   -> takes 2.407 seconds
        batch = dataset[idx]
        e = time.time()
        times_1.append(e-s)

        s = time.time()
        # (bs=128, img_size=128)
        #   - (H, W, C) -> takes 1.661 seconds
        #   - flatten   -> takes 0.673 seconds
        batch = {k: np.array(v) for k, v in batch.items()}
        e = time.time()
        times_2.append(e-s)

        yield batch


# The results are cached
train_ds = train_ds.map(
    preprocess_function,
    remove_columns=["img", "label"],
    batched=True,
    batch_size=16,
    num_proc=2,
    features=features,
)
train_ds = train_ds.with_format("numpy")


# Create sampling rng
input_rng = np.random.default_rng(seed=42)
train_loader = data_loader(input_rng, train_ds, batch_size=128, shuffle=True)

start = time.time()
for idx, batch in enumerate(train_loader):

    end = time.time()
    print(f"batch: {idx} | time: {end - start}")
    start = time.time()


print(f"Average times_1 in data_loader: {np.mean(times_1)}")
print(f"Average times_2 in data_loader: {np.mean(times_2)}")