TensorFlow Datasets -> numpy is 10x faster than HF Datasets -> JAX

The recommended way of using HuggingFace Datasets with JAX is to do something like

from datasets import load_dataset

ds = load_dataset("mnist").with_format("jax")
train_hf = ds["train"]
train_hf = train_hf.iter(batch_size=16)

If I wanted to do something similar with TensorFlow, I might do

import tensorflow_datasets as tfds

ds = tfds.load("mnist")
train_tf = ds["train"].batch(16)
train_tf = iter(tfds.as_numpy(train_tf))

When I run %timeit -n 20 next(train_hf), it takes about 7 ms per iteration. However, running %timeit -n 20 next(train_tf) takes only about 600 µs. TensorFlow is somehow significantly faster than HF Datasets. Granted, TF is converting to numpy arrays, but I think most JAX programs will accept numpy arrays just fine. Is this something that can be improved on, or do I have a problematic usage pattern?

Update: somewhat interestingly, jax.tree_util.tree_map seems to be very lightweight. On the TensorFlow example, running jax.tree_util.tree_map(jnp.asarray, next(train_tf)) doesn't significantly change the runtime. That means that using TensorFlow as an intermediary to get to JAX is much faster, i.e., doing something like

import jax
import jax.numpy as jnp
import tensorflow_datasets as tfds

ds = tfds.load("mnist")
train_tf = ds["train"].batch(16)
train_tf = tfds.as_numpy(train_tf)
for batch in train_tf:
    batch = jax.tree_util.tree_map(jnp.asarray, batch)
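To see why tree_map adds so little overhead: it just walks the small batch dict and applies the function leaf by leaf, and jnp.asarray on a numpy array is cheap (though JAX may still copy to device). A minimal illustration on a fake batch shaped like what tfds.as_numpy yields for MNIST (the shapes here are just placeholders, not what tfds actually returns):

```python
import numpy as np
import jax
import jax.numpy as jnp

# Fake numpy batch standing in for one element of tfds.as_numpy(train_tf).
batch = {
    "image": np.zeros((16, 28, 28, 1), dtype=np.uint8),
    "label": np.arange(16),
}

# tree_map applies jnp.asarray to every leaf, preserving the dict structure.
jax_batch = jax.tree_util.tree_map(jnp.asarray, batch)
print(type(jax_batch["image"]))  # a jax.Array, same shape as the numpy input
```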

Similarly, using TF as an intermediary layer off of HF Datasets to get to JAX is actually faster than using with_format("jax"), e.g.,

from datasets import load_dataset
import tensorflow_datasets as tfds

ds = load_dataset("mnist").with_format("tf")
train_hf = ds["train"].to_tf_dataset(batch_size=16)
train_hf = iter(tfds.as_numpy(train_hf))
%timeit -n 20 jax.tree_util.tree_map(jnp.asarray, next(train_hf))

Entirely possible I'm doing something wrong, or neglecting some clever caching that TF is doing, so I would appreciate any help on this!

mnist is read from numpy array buffers in TFDS, but from PNG files in HF Datasets.

The speed difference is particular to MNIST here, which is so tiny that it is saved as raw arrays rather than as image files. HF Datasets decodes PNG files, while TFDS just reads numpy array buffers.
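A rough way to see why that matters: PNG data is zlib-compressed under the hood, so each read pays a decompress, whereas a raw array buffer is essentially a zero-copy view. This sketch uses stdlib zlib as a stand-in for a real PNG codec (a simplification — actual PNG decoding also involves filtering), not the actual code path in either library:

```python
import time
import zlib
import numpy as np

# One fake 28x28 grayscale image, as raw bytes and as a zlib-compressed blob.
raw = np.random.randint(0, 256, (28, 28), dtype=np.uint8).tobytes()
compressed = zlib.compress(raw)

def bench(fn, n=1000):
    """Average seconds per call over n calls."""
    start = time.perf_counter()
    for _ in range(n):
        fn()
    return (time.perf_counter() - start) / n

# Raw buffer: np.frombuffer is a zero-copy view over the bytes.
t_raw = bench(lambda: np.frombuffer(raw, dtype=np.uint8))
# "Decode" path: decompress first, then wrap in an array.
t_dec = bench(lambda: np.frombuffer(zlib.decompress(compressed), dtype=np.uint8))
print(f"raw buffer: {t_raw * 1e6:.2f} µs, decompress: {t_dec * 1e6:.2f} µs")
```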

That makes sense, but in that case, why is using TensorFlow as an intermediary (HuggingFace → TensorFlow → JAX) still faster?

Could you share the code you used to get that, and your results? In your last code snippet you get the first batch over and over, which is an inaccurate way to benchmark this (you're querying the same data already in RAM, plus possible caching).
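A fairer measurement would time a full pass over fresh batches, rebuilding the iterator for each repeat so no position is re-read. A sketch of what I mean, with a dummy generator standing in for either train_hf or train_tf (the batch shapes are placeholders):

```python
import time
import numpy as np

def make_batches(num_batches=100):
    # Placeholder for rebuilding the real train_hf / train_tf iterator.
    for _ in range(num_batches):
        yield {
            "image": np.zeros((16, 28, 28), dtype=np.uint8),
            "label": np.zeros(16, dtype=np.int64),
        }

def bench_full_pass(make_iter, repeats=3):
    """Average seconds per batch over a complete pass, fresh iterator each repeat."""
    times = []
    for _ in range(repeats):
        it = make_iter()
        start = time.perf_counter()
        n = sum(1 for _ in it)  # consume every batch exactly once
        times.append((time.perf_counter() - start) / n)
    return min(times)  # best of the repeats, to reduce noise

print(f"{bench_full_pass(make_batches) * 1e6:.1f} µs per batch")
```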