Using an IterableDataset’s approximate (buffer-based) shuffle appears to convert my data from float32 to float64, even though the original data is float32. Is this intentional behaviour? Here’s a sketch:
ds = ds.to_iterable_dataset(num_shards=n_shards)
ds = ds.with_format('numpy')
train_loader = ds.iter(batch_size, drop_last_batch=True)
for i, x in enumerate(train_loader):
    # all ok, data is float32
    print({k: v.dtype for k, v in x.items()})

ds = ds.shuffle(seed=seed + epoch, buffer_size=10000)
ds = ds.with_format('numpy')
train_loader = ds.iter(batch_size, drop_last_batch=True)
for i, x in enumerate(train_loader):
    # not ok, data is float64
    print({k: v.dtype for k, v in x.items()})
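I haven't confirmed this in the `datasets` source, so treat it as an assumption: one plausible cause is that the shuffle buffer hands examples through as plain Python objects, and the numpy formatter then rebuilds arrays from Python floats, which NumPy defaults to float64. The numpy-only sketch below reproduces that effect in isolation and defines a hypothetical helper, `cast_floats_to_float32`, as a stopgap that casts any floating-point arrays in a batch back to float32.

```python
import numpy as np

# Assumption for illustration: data loses its dtype by round-tripping
# through plain Python objects before being turned back into numpy arrays.
original = np.array([0.5, 1.25, 2.0], dtype=np.float32)

# .tolist() converts float32 elements into Python floats (double precision).
as_python = original.tolist()

# Building an array from Python floats defaults to float64.
reformatted = np.asarray(as_python)
print(reformatted.dtype)  # float64


def cast_floats_to_float32(batch):
    """Hypothetical helper: cast every floating-point array in a batch dict to float32."""
    return {
        key: value.astype(np.float32)
        if isinstance(value, np.ndarray) and np.issubdtype(value.dtype, np.floating)
        else value
        for key, value in batch.items()
    }


batch = {"features": reformatted, "label": np.array([0, 1, 1], dtype=np.int64)}
fixed = cast_floats_to_float32(batch)
print(fixed["features"].dtype, fixed["label"].dtype)  # float32 int64
```

In the shuffled loop this would be applied as `x = cast_floats_to_float32(x)` at the top of the loop body. Since the values were float32 to begin with, casting the float64 copies back to float32 is lossless, though it adds an extra copy per batch and hides the underlying behaviour rather than explaining it.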