I am using Accelerate for inference with a GPU, and per the documentation, Accelerate uses PyTorch XLA as the backend for TPUs.
Is JAX a faster alternative when running on TPU or GPU?
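For context on where JAX's speed comes from: JAX compiles functions to XLA via `jax.jit`, so the comparison is largely PyTorch's eager/compiled paths versus XLA-compiled JAX code. A minimal sketch (illustrative only, not a benchmark) of jitting a forward pass on whatever backend JAX finds available:

```python
import jax
import jax.numpy as jnp

# jax.jit traces the function once and compiles it with XLA
# for the available backend (CPU, GPU, or TPU).
@jax.jit
def predict(w, x):
    # A toy "model": a single dense layer with tanh activation.
    return jnp.tanh(x @ w)

w = jnp.ones((4, 2))
x = jnp.ones((3, 4))
out = predict(w, x)
print(out.shape)
```

Subsequent calls with the same input shapes reuse the compiled executable, which is where JAX tends to win on TPUs in particular.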
And if it is faster, could Accelerate get a JAX backend?