Is JAX faster than PyTorch XLA?

I am using Accelerate for running model inference on GPU, and per the documentation, Accelerate uses PyTorch XLA as a backend.

Is JAX a better alternative in terms of speed when using a TPU or GPU?
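Worth noting that much of JAX's speed comes from XLA compilation via `jax.jit`, the same compiler PyTorch targets on TPU through `torch_xla`, so the comparison is largely about framework overhead rather than the compiler itself. A tiny sketch of `jit` (the function here is just an illustration):

```python
import jax
import jax.numpy as jnp

@jax.jit  # traces the function once, then runs the XLA-compiled version
def sum_of_squares(x):
    return jnp.sum(x ** 2)

x = jnp.arange(4.0)       # [0., 1., 2., 3.]
y = sum_of_squares(x)     # 0 + 1 + 4 + 9
print(y)                  # 14.0
```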

And if it is faster, could Accelerate get a JAX backend?

As for backends: yes, we have I think six backends in total now (xla, cuda, mps, etc.).

You cannot use JAX as a backend, though; Accelerate is PyTorch-only.