Cannot display examples from IterableDataset

Hello everyone!

I am working with a large image dataset, which I load in streaming mode to save disk space, so I end up with an IterableDataset.
Following this tutorial, I would like to display a grid of examples from each class, just to get a visual sense of the data.

You can find all the code in the tutorial linked above, but this is the part I am interested in and having trouble with:

import random
from PIL import ImageDraw, ImageFont, Image

def show_examples(ds, seed: int = 1234, examples_per_class: int = 3, size=(350, 350)):

    w, h = size
    labels = ds['train'].features['labels'].names
    grid = Image.new('RGB', size=(examples_per_class * w, len(labels) * h))
    draw = ImageDraw.Draw(grid)
    font = ImageFont.load_default() 

    for label_id, label in enumerate(labels):

        # Filter the dataset by a single label, shuffle it, and grab a few samples
        ds_slice = ds['train'].filter(lambda ex: ex['labels'] == label_id).shuffle(seed).select(range(examples_per_class))

        # Plot this label's examples along a row
        for i, example in enumerate(ds_slice):
            image = example['image']
            idx = examples_per_class * label_id + i
            box = (idx % examples_per_class * w, idx // examples_per_class * h)
            grid.paste(image.resize(size), box=box)
            draw.text(box, label, (255, 255, 255), font=font)

    return grid

show_examples(ds, seed=random.randint(0, 1337), examples_per_class=3)

I changed the slice part to:

ds_slice = ds.take(9)

But using take() does not guarantee the presence of n samples from each class.

In short, my problem is that I cannot filter an IterableDataset by a single label and grab just a few examples for each label/class. Is there a way to do this or should I just give up?

Unfortunately I cannot provide my dataset for copyright reasons, but if you need more details I can explain more.

Thank you in advance!

First, make sure you keep the .filter(...) operation, then you can apply take():

ds_slice = ds['train'].filter(lambda ex: ex['labels'] == label_id).take(examples_per_class)

If you end up with fewer than examples_per_class examples, it means that the filter scanned the full dataset and couldn't find examples_per_class examples for that class.
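
To make the filter-then-take pattern concrete, here is a minimal sketch that collects the samples for every class before any drawing happens. The helper name sample_per_class and its parameters are hypothetical and not part of the tutorial. It assumes `stream` is a streaming split such as ds['train'], and that the label names are passed in explicitly, since a streaming dataset may not always expose them through .features.

```python
def sample_per_class(stream, label_names, examples_per_class=3, label_column="labels"):
    """Collect up to `examples_per_class` examples per label from a streaming split.

    `stream` is assumed to behave like a datasets.IterableDataset,
    i.e. it offers lazy .filter(fn) and .take(n) and can be iterated.
    """
    samples = {}
    for label_id, label in enumerate(label_names):
        # Bind label_id as a default argument so the lazy filter keeps
        # the value from this loop iteration.
        matches = stream.filter(lambda ex, lid=label_id: ex[label_column] == lid)
        # Nothing is read until list() iterates over the result.
        samples[label] = list(matches.take(examples_per_class))
        if len(samples[label]) < examples_per_class:
            # The whole stream was scanned without finding enough matches.
            print(f"only {len(samples[label])} example(s) found for {label!r}")
    return samples
```

The returned dict maps each label name to a list of examples, so the drawing loop from the question can iterate over samples[label] instead of ds_slice. If different examples are wanted on each run, calling stream.shuffle(seed=..., buffer_size=...) before the filter is one option, keeping in mind that it only shuffles within a buffer.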

Thank you so much, it worked!

For reference for future readers: I have a 24 GB image dataset (an IterableDataset, very unbalanced across its three classes), and running the above code (with the edit by @lhoestq) took about 20 minutes with examples_per_class=1, so it may not be the fastest way to view example images.

I am satisfied for now and will update this post if I manage to improve the speed.
Thanks again and have a nice day!
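
On the speed issue mentioned above: one likely contributor is that every per-class filter(...) starts reading the stream from the beginning again, so with a rare class the data may be scanned several times. A minimal sketch of a single-pass alternative follows. The function name and its parameters are hypothetical, and it assumes each example is a dict whose label column holds an integer class id.

```python
def collect_examples_single_pass(stream, label_ids, examples_per_class=3, label_column="labels"):
    """Scan `stream` once, keeping the first few examples seen for each wanted label id."""
    found = {label_id: [] for label_id in label_ids}
    remaining = len(found)  # classes that still need more examples
    for example in stream:
        bucket = found.get(example[label_column])
        if bucket is None or len(bucket) >= examples_per_class:
            continue  # unwanted label, or this class is already complete
        bucket.append(example)
        if len(bucket) == examples_per_class:
            remaining -= 1
            if remaining == 0:
                break  # every class is complete, stop reading the stream
    return found
```

For the three-class case this could be called as collect_examples_single_pass(ds['train'], range(3), examples_per_class=1). It still has to read until the rarest class shows up, so it is not guaranteed to be fast, but the stream is read at most once.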