[Solved] Image dataset seems slow for larger image size

Hi,

Recently I tried to train a Flax image-captioning model on TPU, and it was very slow: each batch (256 images) took 26 seconds. I finally found the cause:

My dataset has a column named “image” that stores image arrays (cached to disk).
However, an operation like batch = dataset[idx], where idx is a list of 256 indices, takes about 15 seconds when the image size is (256, 256). If I instead store each image as a flattened 1D array of length 256 * 256 * 3, it is noticeably faster.

The script below demonstrates this situation with image size and batch size 128.

I am wondering if I am doing something wrong, and if there is a better way to handle image datasets.

import numpy as np
import time

from datasets import load_dataset, Dataset


ds = load_dataset("cifar10")
train_ds = ds["train"]
# Take 512 examples
train_ds = train_ds.select(range(512))


# `_images` has shape (batch_size, 32, 32, 3)
# Let's make large images (batch_size, 32 * FACTOR, 32 * FACTOR, 3)
FACTOR = 4


def preprocess_function(examples):

    inputs = {}
    images = examples["img"]

    # convert to an np.float32 array
    _images = np.array(images, dtype=np.float32)

    # Make image larger
    batch_size, image_size = _images.shape[0:2]

    image_size = image_size * FACTOR

    # use (H, W, C) format
    _images = np.concatenate([_images] * FACTOR ** 2, axis=0).reshape((batch_size, image_size, image_size, 3))
    # flatten images to 1D -> `data_loader` will be faster
    # _images = _images.reshape((batch_size, image_size * image_size * 3))

    inputs["image"] = [x for x in _images]
    inputs["label"] = examples["label"]

    return inputs


times_1 = []
times_2 = []


def data_loader(rng: np.random.Generator, dataset: Dataset, batch_size: int, shuffle: bool = False):
    """
    Returns batches of size `batch_size` from the truncated `dataset`.
    Shuffles the batches if `shuffle` is `True`.
    """

    steps_per_epoch = len(dataset) // batch_size

    if shuffle:
        batch_idx = rng.permutation(len(dataset))
    else:
        batch_idx = np.arange(len(dataset))  # np.arange (not range) so that .reshape below works

    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

    for idx in batch_idx:

        s = time.time()
        # (bs=128, img_size=128)
        #   - (H, W, C) -> takes 6.959 seconds
        #   - flatten   -> takes 2.407 seconds
        batch = dataset[idx]
        e = time.time()
        times_1.append(e-s)

        s = time.time()
        # (bs=128, img_size=128)
        #   - (H, W, C) -> takes 1.661 seconds
        #   - flatten   -> takes 0.673 seconds
        batch = {k: np.array(v) for k, v in batch.items()}
        e = time.time()
        times_2.append(e-s)

        yield batch


# The results are cached
train_ds = train_ds.map(
    preprocess_function,
    batched=True,
    batch_size=16,
    num_proc=2,
)

# Create sampling rng
input_rng = np.random.default_rng(seed=42)
train_loader = data_loader(input_rng, train_ds, batch_size=128, shuffle=True)

start = time.time()
for idx, batch in enumerate(train_loader):

    end = time.time()
    print(f"batch: {idx} | time: {end - start}")
    start = time.time()


print(f"Average times_1 in data_loader: {np.mean(times_1)}")
print(f"Average times_2 in data_loader: {np.mean(times_2)}")

Hi!
There are two optimizations that you can use:

  1. Use the Array3D feature type. Otherwise it will consider your arrays to be lists of arbitrary sizes, and it takes some time to collate them again into one numpy array.
from datasets import Array3D, Features

features = Features({
    **train_ds.features,
    "image": Array3D(dtype="int32", shape=(128, 128, 3))
})
train_ds = train_ds.map(
    preprocess_function,
    batched=True,
    batch_size=16,
    num_proc=2,
    features=features
)
  2. Use the ‘numpy’ format. This way, when you access examples they are already numpy arrays, and you also benefit from the end-to-end zero-copy array read from Arrow to numpy.
train_ds = train_ds.with_format("numpy")

It becomes 100 times faster when querying the data loader!

Here are my results without the optimizations:

Average times_1 in data_loader: 7.690094709396362

And then, with these optimizations:

Average times_1 in data_loader: 0.0582084059715271

@lhoestq, this is soooo great! Thank you.

I should learn datasets more seriously :slight_smile:

BTW, after a quick search, I didn’t find these techniques mentioned in the Datasets — datasets 1.14.0 documentation.

Maybe it would be a good idea to add a section about such optimizations? :slight_smile:


Hi, @lhoestq

(using datasets==1.16.1)

I can't reproduce your super good results. Here is what I get

(the complete script with your optimizations is included below)

with your optimizations:

batch: 0 | time: 2.873563528060913
batch: 1 | time: 2.5054991245269775
batch: 2 | time: 2.809781312942505
batch: 3 | time: 2.65997314453125
Average times_1 in data_loader: 2.6919819116592407
Average times_2 in data_loader: 0.01998382806777954

without optimizations:

batch: 0 | time: 8.33098292350769
batch: 1 | time: 7.9121692180633545
batch: 2 | time: 8.10653042793274
batch: 3 | time: 7.925087928771973
Average times_1 in data_loader: 6.770546138286591
Average times_2 in data_loader: 1.2979284524917603

Whenever you get some time, could you help me reproduce your results, please?
Thank you!

The script

import numpy as np
import time

import datasets
from datasets import load_dataset, Dataset


ds = load_dataset("cifar10")
train_ds = ds["train"]
# Take 512 examples
train_ds = train_ds.select(range(512))


# `_images` has shape (batch_size, 32, 32, 3)
# Let's make large images (batch_size, 32 * FACTOR, 32 * FACTOR, 3)
FACTOR = 4


def preprocess_function(examples):

    inputs = {}
    images = examples["img"]

    # convert to an np.int32 array (to match the Array3D int32 dtype below)
    _images = np.array(images, dtype=np.int32)

    # Make image larger
    batch_size, image_size = _images.shape[0:2]

    image_size = image_size * FACTOR

    # use (H, W, C) format
    _images = np.concatenate([_images] * FACTOR ** 2, axis=0).reshape((batch_size, image_size, image_size, 3))
    # flatten images to 1D -> `data_loader` will be faster
    # _images = _images.reshape((batch_size, image_size * image_size * 3))

    inputs["image"] = _images

    return inputs


times_1 = []
times_2 = []

features = datasets.Features({
    "image": datasets.Array3D(dtype="int32", shape=(32 * FACTOR, 32 * FACTOR, 3))
})


def data_loader(rng: np.random.Generator, dataset: Dataset, batch_size: int, shuffle: bool = False):
    """
    Returns batches of size `batch_size` from the truncated `dataset`.
    Shuffles the batches if `shuffle` is `True`.
    """

    steps_per_epoch = len(dataset) // batch_size

    if shuffle:
        batch_idx = rng.permutation(len(dataset))
    else:
        batch_idx = np.arange(len(dataset))

    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

    for idx in batch_idx:

        s = time.time()
        # (bs=128, img_size=128)
        #   - (H, W, C) -> takes 6.959 seconds
        #   - flatten   -> takes 2.407 seconds
        batch = dataset[idx]
        e = time.time()
        times_1.append(e-s)

        s = time.time()
        # (bs=128, img_size=128)
        #   - (H, W, C) -> takes 1.661 seconds
        #   - flatten   -> takes 0.673 seconds
        batch = {k: np.array(v) for k, v in batch.items()}
        e = time.time()
        times_2.append(e-s)

        yield batch


# The results are cached
train_ds = train_ds.map(
    preprocess_function,
    remove_columns=["img", "label"],
    batched=True,
    batch_size=16,
    num_proc=2,
    features=features,
)
train_ds = train_ds.with_format("numpy")


# Create sampling rng
input_rng = np.random.default_rng(seed=42)
train_loader = data_loader(input_rng, train_ds, batch_size=128, shuffle=True)

start = time.time()
for idx, batch in enumerate(train_loader):

    end = time.time()
    print(f"batch: {idx} | time: {end - start}")
    start = time.time()


print(f"Average times_1 in data_loader: {np.mean(times_1)}")
print(f"Average times_2 in data_loader: {np.mean(times_2)}")

Hi! We recently did some optimizations in datasets regarding the way we handle arrays.

Note that for vision datasets like cifar10, we’re adding support for efficiently loading the images directly as PIL images rather than keeping raw lists of integers. This will be available in the next version of datasets. The documentation is also on its way :slight_smile:
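
Roughly, it will look something like this (illustrative only, the exact API may still change before the release):

import numpy as np
from datasets import load_dataset

# With the upcoming Image feature type, image columns are decoded lazily
# into PIL images when accessed, instead of nested lists of integers.
ds = load_dataset("cifar10", split="train")
pil_image = ds[0]["img"]        # a PIL.Image.Image
array = np.asarray(pil_image)   # (32, 32, 3) uint8 array if you need raw pixels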

Hi, looking forward to the new optimizations!

However, I only used cifar10 to make it easy to show the difficulty I am having.
In fact, I am working on a PR for image captioning in Flax, which uses the feature extractor of the ViT model. That feature extractor converts each RGB image into a float32 3D array, and when I index into a preprocessed dataset (one that uses this feature extractor) to get a batch, it takes quite a long time.

The dataset I have is a collection of local images. It would be great if I could make it work without this preprocessing slowness.

For example, getting a batch of 256 elements from a preprocessed dataset takes about 5 seconds.
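
To be concrete, the preprocessing I have in mind looks roughly like the sketch below (the ViT checkpoint name is just a placeholder, and cifar10 stands in for my local images):

import numpy as np
from datasets import Array3D, Features, Value, load_dataset
from transformers import ViTFeatureExtractor

# Placeholder checkpoint; my real data is a set of local images.
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

train_ds = load_dataset("cifar10", split="train").select(range(512))


def preprocess_function(examples):
    # The ViT feature extractor resizes and normalizes, returning float32
    # pixel values of shape (batch_size, 3, 224, 224).
    images = [np.array(img, dtype=np.uint8) for img in examples["img"]]
    pixel_values = feature_extractor(images=images, return_tensors="np")["pixel_values"]
    return {"pixel_values": pixel_values, "label": examples["label"]}


# Declare the fixed array shape so the column is not stored as nested lists.
features = Features({
    "pixel_values": Array3D(dtype="float32", shape=(3, 224, 224)),
    "label": Value("int64"),
})

train_ds = train_ds.map(
    preprocess_function,
    batched=True,
    batch_size=16,
    remove_columns=train_ds.column_names,
    features=features,
)
train_ds = train_ds.with_format("numpy")  # zero-copy numpy reads when indexing batches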

If you have a clear idea, that would be great. Otherwise, I understand you have other priorities :slight_smile:

@lhoestq

I installed from source with pip install --upgrade .[dev] and fixed a minor typing issue:
it works now!!

You guys are amazing!! Thank you so much.

By the way, would you mind telling me which PR made this optimization? I want to get some ideas from it! Thanks!


Average times_1 in data_loader: 0.12009784579277039

batch: 0 | time: 0.11190962791442871
batch: 1 | time: 0.14724326133728027
batch: 2 | time: 0.12398481369018555
batch: 3 | time: 0.12642168998718262
batch: 4 | time: 0.08752894401550293
batch: 5 | time: 0.09127330780029297
batch: 6 | time: 0.09010076522827148
batch: 7 | time: 0.09955096244812012
batch: 8 | time: 0.11309647560119629
batch: 9 | time: 0.09490156173706055
batch: 10 | time: 0.0921792984008789
batch: 11 | time: 0.10338354110717773
batch: 12 | time: 0.10463619232177734
batch: 13 | time: 0.11239099502563477
batch: 14 | time: 0.11689925193786621
batch: 15 | time: 0.10281705856323242
batch: 16 | time: 0.10169291496276855
batch: 17 | time: 0.15300273895263672
batch: 18 | time: 0.10869407653808594
batch: 19 | time: 0.08719658851623535
batch: 20 | time: 0.1176912784576416
batch: 21 | time: 0.13638067245483398
batch: 22 | time: 0.12096691131591797
batch: 23 | time: 0.08860659599304199
batch: 24 | time: 0.10601067543029785
batch: 25 | time: 0.11217999458312988
batch: 26 | time: 0.1530930995941162
batch: 27 | time: 0.14607572555541992
batch: 28 | time: 0.17520356178283691
batch: 29 | time: 0.2369546890258789
batch: 30 | time: 0.13903379440307617
batch: 31 | time: 0.14586281776428223