[Jax] Dataset does not return batched arrays

=====> Colab reproducer <======

I’m using set_format('numpy') for my dataset and using jax.numpy ops to manipulate those numpy arrays.

While debugging, I can see that the arrays have exactly the shapes I expect as they go through their transformations via map. However, when I iterate over the dataset, I get un-batched 2D arrays where I expect 3D (batched) ones.

Using .iter(batch_size=...) slows it down a lot.
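To illustrate the shape difference without the broken Colab link, here is a minimal sketch in plain NumPy. It does not use the datasets library: the `columns` dict, the shapes, and the `iter_batches` helper are all hypothetical stand-ins, assumed only for illustration. It shows that taking one example at a time drops the leading axis (2D), while slicing a contiguous range keeps it (3D).

```python
import numpy as np


def iter_batches(columns, batch_size, drop_last=False):
    """Yield dicts of arrays that keep a leading batch axis.

    Hypothetical helper, not part of any library: it slices each column
    over a contiguous range instead of yielding one example at a time.
    """
    n = len(next(iter(columns.values())))
    for start in range(0, n, batch_size):
        stop = min(start + batch_size, n)
        if drop_last and stop - start < batch_size:
            break
        yield {name: col[start:stop] for name, col in columns.items()}


# Toy stand-in for a dataset column: 10 examples, each a (4, 8) array.
columns = {"features": np.zeros((10, 4, 8), dtype=np.float32)}

# Taking a single example drops the leading axis, so the result is 2D.
single = columns["features"][0]
print(single.shape)

# Slicing a range of examples keeps the leading axis, so each batch is 3D.
batch_shapes = [b["features"].shape for b in iter_batches(columns, batch_size=4)]
print(batch_shapes)
```

The last batch is smaller than batch_size unless drop_last=True is passed, which matters if the downstream jax.numpy code expects a fixed leading dimension.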

  • There’s also a second problem: after the dataset is exhausted and yields no more samples, the code just hangs for some reason. But that’s secondary to the main issue.

Hi! The link to your Google Colab doesn’t work. Can you share it again, please?