If I do something like this:

```python
from datasets import load_dataset

ds = (
    load_dataset("mnist")
    .with_format("torch")
    # Flatten each 28x28 image tensor into a 784-element vector
    .map(lambda ex: {"image": ex["image"].view(-1), "label": ex["label"]})
)

ds["train"]["image"].shape
```
The last line takes a surprisingly long time (~10 s), and I can't see why. I'm still getting a handle on the Hugging Face `datasets` library, so any help would be appreciated!