Hi,
Recently I tried to train an image captioning Flax model on TPU, and I found training is very slow: each batch (256 images) takes about 26 seconds. I finally found the cause:
My dataset has a column named "image" that stores the image arrays (the results are cached). However, an operation like `batch = dataset[idx]`, where `idx` is a list of length 256, takes about 15 seconds when the image size is (256, 256). If I instead store the flattened image array (i.e. a 1D array of length 256 * 256 * 3), it is faster.
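To isolate the indexing cost I mean, here is a minimal standalone sketch (synthetic data instead of my real dataset, with smaller 64x64 images and a batch of 128 indices so it runs quickly; the sizes are just placeholders):

import time
import numpy as np
from datasets import Dataset

# Same pixel data stored two ways: nested (H, W, C) arrays vs flattened 1D arrays.
images = np.random.rand(512, 64, 64, 3).astype(np.float32)
ds_nested = Dataset.from_dict({"image": [x for x in images]})            # (H, W, C) per example
ds_flat = Dataset.from_dict({"image": [x.reshape(-1) for x in images]})  # flattened 1D per example

idx = list(range(128))  # one "batch" of indices
for name, d in [("(H, W, C)", ds_nested), ("flattened 1D", ds_flat)]:
    s = time.time()
    _ = d[idx]  # the batch-indexing operation that is slow for me
    print(f"{name}: {time.time() - s:.3f} seconds")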
The script below demonstrates the situation, with an image size of 128 and a batch size of 128.
I am wondering if I am doing something wrong, and if there is a better way to handle image datasets.
import numpy as np
import time
from datasets import load_dataset, Dataset
ds = load_dataset("cifar10")
train_ds = ds["train"]
# Take 512 examples
train_ds = train_ds.select(range(512))
# `_images` has shape (batch_size, 32, 32, 3)
# Let's make large images (batch_size, 32 * FACTOR, 32 * FACTOR, 3)
FACTOR = 4
def preprocess_function(examples):
    inputs = {}
    images = examples["img"]
    # convert to np.float32 array
    _images = np.array(images, dtype=np.float32)
    # Make images larger
    batch_size, image_size = _images.shape[0:2]
    image_size = image_size * FACTOR
    # use (H, W, C) format
    _images = np.concatenate([_images] * FACTOR ** 2, axis=0).reshape((batch_size, image_size, image_size, 3))
    # flatten images to 1D -> `data_loader` will be faster
    # _images = _images.reshape((batch_size, image_size * image_size * 3))
    inputs["image"] = [x for x in _images]
    inputs["label"] = examples["label"]
    return inputs
times_1 = []
times_2 = []
def data_loader(rng: np.random.Generator, dataset: Dataset, batch_size: int, shuffle: bool = False):
    """
    Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
    Shuffle batches if `shuffle` is `True`.
    """
    steps_per_epoch = len(dataset) // batch_size

    if shuffle:
        batch_idx = rng.permutation(len(dataset))
    else:
        # use an ndarray (not `range`) so the slicing and `reshape` below work
        batch_idx = np.arange(len(dataset))

    batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
    batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

    for idx in batch_idx:
        s = time.time()
        # (bs=128, img_size=128)
        # - (H, W, C) -> takes 6.959 seconds
        # - flatten -> takes 2.407 seconds
        batch = dataset[idx]
        e = time.time()
        times_1.append(e - s)

        s = time.time()
        # (bs=128, img_size=128)
        # - (H, W, C) -> takes 1.661 seconds
        # - flatten -> takes 0.673 seconds
        batch = {k: np.array(v) for k, v in batch.items()}
        e = time.time()
        times_2.append(e - s)

        yield batch
# The results are cached
train_ds = train_ds.map(
    preprocess_function,
    batched=True,
    batch_size=16,
    num_proc=2,
)
# Create sampling rng
input_rng = np.random.default_rng(seed=42)
train_loader = data_loader(input_rng, train_ds, batch_size=128, shuffle=True)
start = time.time()
for idx, batch in enumerate(train_loader):
    end = time.time()
    print(f"batch: {idx} | time: {end - start}")
    start = time.time()
print(f"Average times_1 in data_loader: {np.mean(times_1)}")
print(f"Average times_2 in data_loader: {np.mean(times_2)}")