First, make sure you keep the .filter(...)
operation; then you can apply take:
```python
# Keep only examples of one class, then stop after examples_per_class of them
ds_slice = ds['train'].filter(lambda ex: ex['labels'] == label_id).take(examples_per_class)
```
If you end up with fewer than examples_per_class examples,
it means that the filter scanned the full dataset and couldn’t find examples_per_class
examples for that class.
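To make the "fewer than examples_per_class" check concrete, here is a minimal plain-Python sketch of the same filter-then-take logic. It does not use the datasets library. A generator stands in for a lazily evaluated filter (roughly the behavior you would expect when streaming), and itertools.islice stands in for take. The helpers filter_then_take and sample_per_class are hypothetical names used only for illustration.

```python
from itertools import islice
from typing import Dict, Iterable, Iterator, List


def filter_then_take(examples: Iterable[dict], label_id: int, n: int) -> List[dict]:
    """Hypothetical helper: keep examples whose 'labels' == label_id, stop after n matches."""
    # Lazy filter: nothing is evaluated until islice pulls items from it.
    matches: Iterator[dict] = (ex for ex in examples if ex["labels"] == label_id)
    # islice stops pulling once n matches are found, so the source is only
    # scanned to the end when it contains fewer than n matching examples.
    return list(islice(matches, n))


def sample_per_class(examples: List[dict], label_ids: Iterable[int], n: int) -> Dict[int, List[dict]]:
    """Hypothetical helper: collect up to n examples per class and report short classes."""
    per_class: Dict[int, List[dict]] = {}
    for label_id in label_ids:
        subset = filter_then_take(examples, label_id, n)
        if len(subset) < n:
            # Fewer than n results means the whole source was scanned without finding n matches.
            print(f"label {label_id}: only {len(subset)} of {n} examples found after a full scan")
        per_class[label_id] = subset
    return per_class


if __name__ == "__main__":
    # Toy data: labels cycle 0, 1, 2, so label 0 has 4 examples and labels 1 and 2 have 3 each.
    data = [{"text": f"ex{i}", "labels": i % 3} for i in range(10)]
    result = sample_per_class(data, [0, 1, 2, 3], n=4)
    for label_id, subset in result.items():
        print(label_id, len(subset))
```

With this toy data, labels 1, 2 and 3 each come back short, which is the signal that the source simply does not hold enough examples of those classes.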