=====> Colab reproducer <======
I’m using set_format('numpy') for my dataset and jax.numpy ops to manipulate those NumPy arrays.
While debugging, I can see that the shapes are exactly what I expect as the data goes through its transformations via map.
- However, when I iterate over the dataset, I get un-batched arrays that are 2D when they should be 3D. Using .iter(batch_size=...) instead slows it down a lot.
- There’s also a second problem: after the dataset is exhausted and yields no more samples, the code just hangs. But that’s secondary to the main issue.
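To make the 2D-vs-3D shape difference concrete, here is a minimal NumPy-only sketch. It does not use 🤗 Datasets. The column shape (10 examples, each of shape (4, 8)) and the helpers iterate_rows / iterate_batches are hypothetical. The sketch assumes that iterating the dataset behaves like indexing one example at a time, so each item has one fewer dimension than a batch. Slicing contiguous rows, by contrast, keeps a leading batch axis.

```python
import numpy as np

# Hypothetical stand-in for one formatted column: 10 examples,
# each a 2D array of shape (4, 8). Real shapes will differ.
column = np.arange(10 * 4 * 8, dtype=np.float32).reshape(10, 4, 8)


def iterate_rows(col):
    """Mimics per-example iteration: each item is a single 2D example."""
    for i in range(len(col)):
        yield col[i]


def iterate_batches(col, batch_size, drop_last=False):
    """Yields contiguous slices, so each item keeps a leading batch axis (3D)."""
    for start in range(0, len(col), batch_size):
        batch = col[start:start + batch_size]
        # Optionally skip a final batch that is smaller than batch_size.
        if drop_last and len(batch) < batch_size:
            break
        yield batch


row_shapes = [row.shape for row in iterate_rows(column)]
batch_shapes = [b.shape for b in iterate_batches(column, batch_size=4)]
print(row_shapes[0])   # (4, 8)
print(batch_shapes)    # [(4, 4, 8), (4, 4, 8), (2, 4, 8)]
```

In 🤗 Datasets, the rough analogue of iterate_batches is slicing the formatted dataset (for example ds[start:start + batch_size]), which under set_format('numpy') returns a dict of column arrays. Whether that is faster than .iter(batch_size=...) for this dataset has not been measured here.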