We are trying to load datasets where the image column stores PIL.PngImagePlugin.PngImageFile
images. However, iterating over these datasets is extremely slow.
What I have found:
- The image column is what causes the slowdown. With the column removed, iteration is fast, as expected.
- Iterating is ~2x faster when the column holds a single image than when it holds a list of images (i.e., the feature is a Sequence of Image objects). We often need multiple images per sample, so we have to work with lists of images.
- It is ~17x faster to store paths to PNG files and load them using PIL.Image.open, as opposed to iterating over a Dataset with an Image column, and ~30x faster compared to a Sequence of Images. See a simple script below with an openly available dataset.
It would be great to understand the standard practices for storing and loading multimodal datasets (image + text).
Loading image data seems a bit underdeveloped? (e.g., dataset.decode only works with IterableDataset, but that is not clear from the docs)
Thanks!
from datasets import load_dataset, load_from_disk
from PIL import Image
from pathlib import Path
ds = load_dataset("getomni-ai/ocr-benchmark")
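# Make sure the output directory exists before saving the images
Path("/tmp/ds_files/images").mkdir(parents=True, exist_ok=True)
for idx, sample in enumerate(ds["test"]):
    image = sample["image"]
    image.save(f"/tmp/ds_files/images/image_{idx}.png")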
ds.save_to_disk("/tmp/ds_columns")
# Remove the 'image' column
ds["test"] = ds["test"].remove_columns(["image"])
# Create image paths for each sample
image_paths = [f"images/image_{idx}.png" for idx in range(len(ds["test"]))]
# Add the 'image_path' column to the dataset
ds["test"] = ds["test"].add_column("image_path", image_paths)
# Save the updated dataset
ds.save_to_disk("/tmp/ds_files")
files_path = Path("/tmp/ds_files")
column_path = Path("/tmp/ds_columns")
# load and benchmark
ds_file = load_from_disk(files_path)
ds_column = load_from_disk(column_path)
import time
images_files = []
start = time.time()
for idx in range(len(ds_file["test"])):
    image_path = files_path / ds_file["test"][idx]["image_path"]
    image = Image.open(image_path)
    images_files.append(image)
end = time.time()
print(f"Time taken to load images from files: {end - start} seconds")
# Time taken to load images from files: 1.2364635467529297 seconds
images_column = []
start = time.time()
for idx in range(len(ds_column["test"])):
    images_column.append(ds_column["test"][idx]["image"])
end = time.time()
print(f"Time taken to load images from columns: {end - start} seconds")
# Time taken to load images from columns: 20.49347186088562 seconds
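One caveat when reading these numbers: PIL.Image.open is lazy. It identifies the file and reads the header, but the pixel data is not decoded until it is used or load() is called. If the Image column returns fully decoded images (an assumption worth checking against the installed datasets version), the file-based loop above is doing less work per sample. Below is a minimal sketch of timing both variants on throwaway PNGs; the time_open helper and the generated files are hypothetical, and no timings are claimed.

```python
import tempfile
import time
from pathlib import Path

from PIL import Image


def time_open(paths, force_decode):
    """Open every path; optionally force pixel decoding with load()."""
    start = time.perf_counter()
    opened = []
    for path in paths:
        img = Image.open(path)  # lazy: only the header is read here
        if force_decode:
            img.load()  # forces the pixel data to be decoded
        opened.append(img)
    return opened, time.perf_counter() - start


# Hypothetical test files: a handful of small PNGs in a temporary directory.
tmp_dir = Path(tempfile.mkdtemp())
paths = []
for i in range(5):
    path = tmp_dir / f"image_{i}.png"
    Image.new("RGB", (256, 256), color=(i, i, i)).save(path)
    paths.append(path)

lazy_images, lazy_seconds = time_open(paths, force_decode=False)
full_images, full_seconds = time_open(paths, force_decode=True)
print(f"lazy open: {lazy_seconds:.4f} s, open + load: {full_seconds:.4f} s")
```

Running the file-based loop with load() gives a like-for-like baseline for the column-based loop.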