Hello all, I was trying to flatten the images in the MNIST dataset, but I found that the result gets an extra dimension when I overwrite the existing `image` column instead of adding a new key. It is not a big deal, but I am curious about why this happens. I would really appreciate any explanation!
# Take only the first 10 MNIST training examples, and have the dataset hand
# back torch tensors when it is indexed or passed to ``map``.
ds = load_dataset("mnist", split="train[:10]")
ds = ds.with_format("torch")
def transforms(examples):
    """Attach a flattened copy of the image under a new ``new_image`` key.

    Each slice taken along the first axis of ``examples["image"]`` is
    reshaped to 1-D. The resulting list is stored under ``"new_image"``,
    and the original ``"image"`` entry is left untouched. The same dict is
    mutated in place and returned.

    NOTE(review): ``ds.map(transforms)`` is called without ``batched=True``,
    so this receives one example at a time. Under the torch format the image
    presumably has a leading channel axis of length 1, which would make the
    stored value (1, 784) per example. Confirm.
    """
    flattened = []
    for plane in examples["image"]:
        flattened.append(plane.reshape(-1))
    examples["new_image"] = flattened
    return examples
# Flatten into a separate "new_image" column and keep "image" as it was.
new_ds = ds.map(transforms)
new_ds["new_image"].shape  # torch.Size([10, 1, 784]) -- the expected shape
def transforms_2(examples):
    """Replace ``examples["image"]`` with a list of flattened 1-D slices.

    Each slice along the first axis of the current ``"image"`` value is
    reshaped to 1-D, and the list of results overwrites ``"image"`` itself.
    The same dict is mutated in place and returned.

    NOTE(review): when this overwrites ``image`` through ``Dataset.map``,
    the snippet observes an extra axis (torch.Size([10, 1, 1, 784])) compared
    with writing to a new key. That is presumably because the column keeps
    its original ``Image`` feature type. Confirm.
    """
    def _flatten(arr):
        return arr.reshape(-1)

    examples["image"] = list(map(_flatten, examples["image"]))
    return examples
# Why overwriting "image" gained an extra axis:
# - Without explicit ``features``, ``map`` appears to try the column's existing
#   ``Image`` feature type first, and that type accepts nested-list data.
# - The flattened (1, 784) value is therefore stored as a 2-D grayscale image.
# - The torch format adds a channel axis when it decodes that image, giving
#   torch.Size([10, 1, 1, 784]) instead of [10, 1, 784].
# NOTE(review): this explanation is inferred from the observed shapes and
# datasets' feature handling; confirm against the installed datasets version.
#
# The fix passes explicit ``features`` so that "image" gets the same plain
# nested-sequence type that ``map`` inferred for the flattened "new_image"
# column of ``new_ds``. This keeps the data from being re-encoded as an image.
flat_features = ds.features.copy()
flat_features["image"] = new_ds.features["new_image"]
new_ds2 = ds.map(transforms_2, features=flat_features)
new_ds2["image"].shape  # expected torch.Size([10, 1, 784]), matching new_ds["new_image"]