The recommended way of using Hugging Face datasets with JAX is to do something like:

```python
ds = load_dataset("mnist").with_format("jax")
train_hf = ds["train"]
train_hf = train_hf.iter(batch_size=16)
```
If I wanted to do something similar with TensorFlow, I might do:

```python
ds = tfds.load("mnist")
train_tf = ds["train"].batch(16)
train_tf = tfds.as_numpy(train_tf)
```
When I run `%timeit -n 20 next(train_hf)`, it takes about 7 ms per iteration. However, `%timeit -n 20 next(train_tf)` takes only about 600 µs. TensorFlow is somehow significantly faster than HF datasets. Granted, TF is converting to NumPy arrays rather than JAX arrays, but most JAX programs will accept NumPy arrays just fine. Is this something that can be improved, or do I have a problematic usage pattern?
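For anyone who wants to reproduce the measurement outside IPython, here is a minimal sketch of the timing setup using the stdlib `timeit` module. A toy generator stands in for the real dataset iterator (the actual MNIST pipelines aren't needed to illustrate the methodology); the key points are consuming the *same* iterator with `next()` and discarding the first batch so one-time setup cost isn't counted:

```python
import timeit
import numpy as np

def make_batches(n_batches=100, batch_size=16):
    # Toy stand-in for a dataset iterator: yields dict batches shaped
    # like batched MNIST, e.g. {"image": (16, 28, 28), "label": (16,)}.
    for _ in range(n_batches):
        yield {
            "image": np.zeros((batch_size, 28, 28), dtype=np.uint8),
            "label": np.zeros((batch_size,), dtype=np.int64),
        }

it = make_batches()
next(it)  # warm-up: exclude any first-batch setup cost from the timing

# Per-batch cost, measured the same way `%timeit -n 20 next(train_hf)` does.
per_call = timeit.timeit(lambda: next(it), number=20) / 20
print(f"{per_call * 1e6:.1f} us per batch")
```

Swapping the toy generator for `train_hf` or `train_tf` gives numbers comparable to the `%timeit` figures above.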
Update: somewhat interestingly, `jax.tree_util.tree_map` seems to be very lightweight. In the TensorFlow example, running `jax.tree_util.tree_map(jnp.asarray, next(train_tf))` doesn't significantly change the runtime. That means using TensorFlow as an intermediary to get to JAX is much faster, i.e., doing something like:

```python
ds = tfds.load("mnist")
train_tf = ds["train"].batch(16)
train_tf = tfds.as_numpy(train_tf)
for batch in train_tf:
    jax.tree_util.tree_map(jnp.asarray, batch)
```
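For intuition about why the `tree_map` step adds so little overhead: it only walks the batch's structure (here, a two-key dict) and applies the function to each leaf, so the traversal itself is trivial and the total cost is dominated by the per-leaf conversion. A minimal pure-Python stand-in (not the real `jax.tree_util.tree_map`, and with `np.asarray` playing the role of `jnp.asarray`, so no JAX install is needed to run it):

```python
import numpy as np

def tree_map(fn, tree):
    # Minimal stand-in for jax.tree_util.tree_map, handling only the
    # dict-of-arrays batches that tfds.as_numpy yields.
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    return fn(tree)  # leaf: apply the function directly

batch = {
    "image": np.zeros((16, 28, 28), dtype=np.uint8),
    "label": np.zeros((16,), dtype=np.int64),
}

# np.asarray on an existing ndarray is a no-copy pass-through, which is
# why the per-leaf cost here is negligible; jnp.asarray's cost depends
# on where the resulting JAX array needs to live.
converted = tree_map(np.asarray, batch)
print(converted["image"].shape, converted["label"].shape)
```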
Similarly, using TF as an intermediary layer off of HF datasets to get to JAX is actually faster than using `.with_format("jax")` directly:

```python
ds = load_dataset("mnist").with_format("tf")
train_hf = ds["train"].to_tf_dataset(batch_size=16)
train_hf = tfds.as_numpy(train_hf)

%timeit -n 20 jax.tree_util.tree_map(jnp.asarray, next(train_hf))
```
It's entirely possible I'm doing something wrong, or neglecting some clever caching that TF is doing, so I would appreciate any help on this!