Hello everyone!
I am working with a large dataset of images which I am loading in Streaming mode to save disk space. Doing so, I end up with an IterableDataset.
Following this tutorial, I would like to display a grid of examples from each class, just to get a quick visual overview of the data.
All the code is in the tutorial linked above, but this is the part I am interested in and having trouble with:
import random
from PIL import ImageDraw, ImageFont, Image

def show_examples(ds, seed: int = 1234, examples_per_class: int = 3, size=(350, 350)):
    w, h = size
    labels = ds['train'].features['labels'].names
    grid = Image.new('RGB', size=(examples_per_class * w, len(labels) * h))
    draw = ImageDraw.Draw(grid)
    font = ImageFont.load_default()
    for label_id, label in enumerate(labels):
        # Filter the dataset by a single label, shuffle it, and grab a few samples
        ds_slice = ds['train'].filter(lambda ex: ex['labels'] == label_id).shuffle(seed).select(range(examples_per_class))
        # Plot this label's examples along a row
        for i, example in enumerate(ds_slice):
            image = example['image']
            idx = examples_per_class * label_id + i
            box = (idx % examples_per_class * w, idx // examples_per_class * h)
            grid.paste(image.resize(size), box=box)
            draw.text(box, label, (255, 255, 255), font=font)
    return grid

show_examples(ds, seed=random.randint(0, 1337), examples_per_class=3)
I changed the slice part to:
ds_slice = ds.take(9)
But using take() does not guarantee the presence of n samples from each class.
In short, my problem is that I cannot filter an IterableDataset by a single label and grab just a few examples for each label/class. Is there a way to do this or should I just give up?
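To make the goal concrete, here is a minimal sketch of the behaviour I am after: a single pass over the stream that keeps the first few examples of each label and stops as soon as every class is full. It only assumes that the examples are dicts with an integer 'labels' field, as in the tutorial. The helper name collect_per_class and the toy stream are made up for illustration and are not part of the datasets library.

```python
from collections import defaultdict
from typing import Any, Dict, Iterable, List


def collect_per_class(
    examples: Iterable[Dict[str, Any]],
    num_classes: int,
    examples_per_class: int = 3,
    label_key: str = "labels",
) -> Dict[int, List[Dict[str, Any]]]:
    """Hypothetical helper: one pass over a stream, keeping the first
    `examples_per_class` examples seen for each label."""
    picked: Dict[int, List[Dict[str, Any]]] = defaultdict(list)
    full_classes = 0
    for example in examples:
        bucket = picked[example[label_key]]
        if len(bucket) < examples_per_class:
            bucket.append(example)
            if len(bucket) == examples_per_class:
                full_classes += 1
                # Stop reading the stream once every class has enough examples
                if full_classes == num_classes:
                    break
    return dict(picked)


# Toy stand-in for a streamed split (assumption: real examples hold PIL images)
toy_stream = ({"image": f"img_{i}", "labels": i % 3} for i in range(100))
samples = collect_per_class(toy_stream, num_classes=3, examples_per_class=2)
```

With a real streamed split, the idea would be to pass ds['train'] (optionally after .shuffle(seed=..., buffer_size=...), which for an IterableDataset only shuffles within a buffer) and then paste samples[label_id] along each row of the grid. Calling ds['train'].filter(lambda ex: ex['labels'] == label_id).take(examples_per_class) once per label might also work, at the cost of re-reading the stream for every class, which could be slow for rare classes.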
Unfortunately I cannot provide my dataset for copyright reasons, but if you need more details I can explain more.
Thank you in advance!