Deal with large image datasets

I have a raw dataset that contains image file paths. The preprocessing function reads these files and prepares the features as NumPy float32 arrays with shape (height, width, 3). Using datasets directly caches the results: for 600K images of size (224, 224, 3), the cache takes > 350 GB. I plan to extend my work to the CC3M or even CC12M datasets, and the required disk space becomes too large for me.
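
A quick back-of-the-envelope check of where that figure comes from:

# 600K float32 arrays of shape (224, 224, 3), 4 bytes per value
n_images = 600_000
bytes_per_image = 224 * 224 * 3 * 4
print(n_images * bytes_per_image / 1e9)  # ~361 GB, consistent with the > 350 GB above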

I am wondering what a good approach to this would be. I don’t want to cache the whole results, considering the disk usage, but I can’t keep everything in memory either.

My idea is:

  • At the start of each epoch, shuffle the raw dataset.
  • Split the raw dataset into several smaller ones, and use map(keep_in_memory=True) on each of them in turn.
    • For each split, iterate over it in batches.

Is there a better way to deal with large image datasets using datasets’ built-in methods, or a better approach than what I describe above?

Here is some example pseudo code:

import numpy as np
from datasets import load_dataset

# The dataset contains image file paths.
# Assume we have 100M examples.
# `preprocess_function`, `features` and `data_loader` are defined elsewhere.
ds = load_dataset("...")
train_ds_raw = ds["train"]

EPOCHS = 3
SPLIT_SIZE = 512
SPLITS = len(train_ds_raw) // SPLIT_SIZE
for epoch in range(EPOCHS):

    # shuffle at each epoch start
    train_ds = train_ds_raw.shuffle()

    for idx in range(SPLITS):

        # Use `map()` on smaller portions, so we can keep map()'s results in memory instead of writing to cache
        start_idx = SPLIT_SIZE * idx
        end_idx = SPLIT_SIZE * (idx + 1)
        _train_ds = train_ds.select(range(start_idx, end_idx))

        # Load image files from disk + feature processing: multiprocessing
        _train_ds = _train_ds.map(
            preprocess_function,
            batched=True,
            batch_size=16,
            num_proc=4,
            features=features,
            keep_in_memory=True,
        )
        _train_ds = _train_ds.with_format("numpy")

        # Create sampling rng
        input_rng = np.random.default_rng(seed=42)

        # Some more (fast) processing without multiprocessing
        # No need to shuffle here
        train_loader = data_loader(input_rng, _train_ds, batch_size=128, shuffle=False)

        # training
        for batch_idx, batch in enumerate(train_loader):
            pass

Hi,

We are currently working on the Image feature, which will make this whole process faster.

In the meantime, you can use Dataset.with_transform and load the images on the fly as you need them. This way, only the current batch occupies memory. The only downside of this approach is that you’ll have to implement any additional formatting yourself inside with_transform (you can also use the DataLoader’s collate_fn for this), because with_format and with_transform currently cannot be chained.
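
A minimal sketch of that approach, assuming an "image_path" column; the column name, the resize to 224×224, and the scaling by 255 are illustrative choices, not part of the answer:

import numpy as np
from datasets import load_dataset
from PIL import Image
from torch.utils.data import DataLoader  # only needed for the collate_fn variant

ds = load_dataset("...")["train"]

def decode_batch(batch):
    # Read and preprocess only the rows that are actually being accessed.
    batch["pixel_values"] = [
        np.asarray(Image.open(path).convert("RGB").resize((224, 224)), dtype=np.float32) / 255.0
        for path in batch["image_path"]
    ]
    return batch

# The transform runs lazily on __getitem__, so nothing is written to the cache
# and only the current batch occupies memory.
ds = ds.with_transform(decode_batch)

batch = ds[:16]  # decodes just these 16 images

# Any extra formatting can also go into a DataLoader collate_fn, as mentioned above.
def collate_fn(examples):
    return {"pixel_values": np.stack([ex["pixel_values"] for ex in examples])}

loader = DataLoader(ds, batch_size=128, shuffle=True, collate_fn=collate_fn)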
