The recommended way of using HuggingFace datasets with Jax is to do something like
from datasets import load_dataset

ds = load_dataset("mnist").with_format("jax")
train_hf = ds["train"]
train_hf = train_hf.iter(batch_size=16)  # generator of dict batches
If I wanted to do something similar with TensorFlow, I might do
import tensorflow_datasets as tfds

ds = tfds.load("mnist")
train_tf = ds["train"].batch(16)
train_tf = iter(tfds.as_numpy(train_tf))  # iterator over dicts of numpy arrays
When I run %timeit -n 20 next(train_hf), it takes about 7 ms per iteration. However, running %timeit -n 20 next(train_tf) takes only about 600 µs, so TensorFlow is significantly faster than HF datasets here. Granted, TF is converting to numpy arrays rather than jax arrays, but most Jax programs will accept numpy arrays just fine. Is this something that can be improved, or do I have a problematic usage pattern?
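For anyone who wants to reproduce this outside IPython, here is a minimal sketch of the same measurement using the standard-library timeit module (same dataset and batch size as above; absolute numbers will of course vary by machine and backend):
import timeit

import tensorflow_datasets as tfds
from datasets import load_dataset

hf_iter = load_dataset("mnist").with_format("jax")["train"].iter(batch_size=16)
tf_iter = iter(tfds.as_numpy(tfds.load("mnist")["train"].batch(16)))

# average seconds per next() call, best of 5 trials of 20 calls each
print("hf datasets:", min(timeit.repeat(lambda: next(hf_iter), number=20)) / 20)
print("tfds       :", min(timeit.repeat(lambda: next(tf_iter), number=20)) / 20)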
Update: somewhat interestingly, jax.tree_util.tree_map seems to be very lightweight. On the TensorFlow example, running jax.tree_util.tree_map(jnp.asarray, next(train_tf)) doesn’t significantly change the runtime. That means using TensorFlow as an intermediary to get to Jax is much faster, i.e., doing something like:
import jax
import jax.numpy as jnp
import tensorflow_datasets as tfds

ds = tfds.load("mnist")
train_tf = ds["train"].batch(16)
train_tf = tfds.as_numpy(train_tf)
for batch in train_tf:
    jax.tree_util.tree_map(jnp.asarray, batch)  # numpy -> jax arrays per leaf
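To check that these converted batches are usable downstream, here is a hypothetical jitted step consuming them (train_step is a placeholder I made up for illustration, not part of the snippets above; the "image" key comes from tfds’s mnist feature dict):
@jax.jit
def train_step(batch):
    # placeholder computation; a real step would compute a loss and update params
    return jnp.mean(batch["image"].astype(jnp.float32))

for batch in tfds.as_numpy(ds["train"].batch(16)):  # fresh iterator
    train_step(jax.tree_util.tree_map(jnp.asarray, batch))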
Similarly, using TF as an intermediate layer on top of HF datasets to get to Jax is actually faster than using with_format("jax"), e.g.:
ds = load_dataset("mnist").with_format("tf")
train_hf = ds["train"].to_tf_dataset(batch_size=16)
train_hf = iter(tfds.as_numpy(train_hf))
%timeit -n 20 jax.tree_util.tree_map(jnp.asarray, next(train_hf))
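And the same stdlib-timeit sketch as before, applied to this third pipeline for a like-for-like number outside IPython (again, timings will vary):
import timeit

hf_tf_iter = iter(tfds.as_numpy(
    load_dataset("mnist").with_format("tf")["train"].to_tf_dataset(batch_size=16)
))

t = min(timeit.repeat(
    lambda: jax.tree_util.tree_map(jnp.asarray, next(hf_tf_iter)),
    number=20,
)) / 20
print(f"hf -> to_tf_dataset -> numpy -> jax: {t * 1e6:.0f} µs/iter")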
Entirely possible I’m doing something wrong, or neglecting some clever caching that TF is doing, so I’d appreciate any help on this!