I'm using set_format('numpy') for my dataset and jax.numpy ops to manipulate those numpy arrays.
While debugging, I can see that the shapes are exactly what I expect as they go through their transformations in map. However, when I iterate over the dataset, I get un-batched arrays that are 2D when they should be 3D.
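Below is a minimal numpy-only model of this symptom. It assumes the map call uses batched=True and that the dataset stores examples row by row: the mapped function sees stacked (B, H, W) arrays, but its output is written back one example per row, so plain iteration hands back (H, W) arrays. batched_map and transform are hypothetical stand-ins for illustration, not datasets APIs.

```python
import numpy as np

# Hypothetical stand-in for a dataset column: 10 examples, each a 4x5 array.
# Assumption: this mimics row-oriented storage, one array per example.
rows = [np.random.rand(4, 5).astype(np.float32) for _ in range(10)]


def batched_map(rows, fn, batch_size):
    """Apply fn to stacked batches, then split the result back into rows."""
    out = []
    for start in range(0, len(rows), batch_size):
        batch = np.stack(rows[start:start + batch_size])  # shape (B, 4, 5)
        result = fn(batch)
        out.extend(list(result))  # stored one example per row again
    return out


seen_shapes = []


def transform(batch):
    seen_shapes.append(batch.shape)  # 3D inside the mapped function
    return batch * 2.0


mapped = batched_map(rows, transform, batch_size=4)
# Iterating afterwards yields individual 2D examples, not 3D batches.
iterated_shapes = [row.shape for row in mapped]
```

In this model the 3D shape only exists inside the mapped function; anything that iterates over the stored rows sees 2D arrays unless it re-batches them.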
Iterating with .iter(batch_size=...) slows it down a lot.
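For comparison, here is a sketch of what batched iteration is expected to produce, assuming the whole column fits in memory as a single (N, H, W) numpy array. Batches are taken as contiguous slices, which are views and avoid re-stacking per-row arrays. iter_batches is a hypothetical helper, not part of the datasets library, and this says nothing about how .iter is implemented internally.

```python
import numpy as np


def iter_batches(column, batch_size, drop_last=False):
    """Yield contiguous (B, ...) slices of an array holding every example."""
    n = column.shape[0]
    # Optionally skip the final partial batch.
    stop = n - (n % batch_size) if drop_last else n
    for start in range(0, stop, batch_size):
        yield column[start:start + batch_size]  # a view, no per-row copy


# Hypothetical column of 10 examples, each 4x5, held as one (10, 4, 5) array.
column = np.zeros((10, 4, 5), dtype=np.float32)
shapes = [b.shape for b in iter_batches(column, batch_size=4)]
shapes_drop_last = [b.shape for b in iter_batches(column, 4, drop_last=True)]
```

Each yielded batch is 3D, with the last one shorter unless drop_last is set.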
- There's also a second problem: after the dataset is exhausted and yields no more samples, the code just hangs. But that's secondary to the main issue.