JAX on CUDA for FlaxStableDiffusionPipeline?

I'm curious how the JAX implementation of Stable Diffusion compares to PyTorch for those of us who don't have a TPU to play with at home, so I'm trying to run it on a CUDA GPU like this:

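import jax
import diffusers

jax_device = jax.local_devices(backend='gpu')[0]

pipe, params = diffusers.FlaxStableDiffusionPipeline.from_pretrained(
    ...,  # model id and other arguments not shown
    revision="bf16",
    feature_extractor=NullFeatureExtractor(),  # stub class defined elsewhere; it's cranky if we use None here
)
params = jax.device_put(params, jax_device)

prompt_inputs = pipe.prepare_inputs("An astronaut riding JAX on Mars.")
result = pipe(prompt_inputs, params, ...)  # PRNG key and other arguments not shown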

For some reason, the params initially load onto the CPU instead of the GPU, hence the device_put. Despite that, I still get errors like this:

primitive arguments must be colocated on the same device (C++ jax.jit). Arguments are on devices: gpu:0 and TFRT_CPU_0
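For context: the JAX FAQ on controlling data placement describes arrays placed with an explicit jax.device_put(x, device) as "committed" to that device, and a computation whose committed inputs sit on different devices fails with this kind of colocation error. So moving params alone is not enough if any other array reaching the same jitted call is committed to the CPU. A minimal sketch of putting every input you pass on one device follows; colocate is a hypothetical helper, the toy pytrees stand in for the real params, prompt ids and PRNG key, and it assumes JAX 0.4 or later (for jax.Array.devices()). It only rules out the arguments you pass in; it cannot move an array the pipeline commits to the CPU internally.

```python
import jax
import jax.numpy as jnp


def colocate(device, *trees):
    """Commit every array leaf of each pytree to `device`.

    jax.device_put accepts a whole pytree (dict, tuple, ...), so nested
    parameter dicts are handled in a single call per tree.
    """
    return tuple(jax.device_put(tree, device) for tree in trees)


if __name__ == "__main__":
    # Use a GPU if JAX can see one, otherwise fall back to the default device.
    try:
        device = jax.devices("gpu")[0]
    except RuntimeError:
        device = jax.devices()[0]

    toy_params = {"unet": {"kernel": jnp.ones((2, 2))}}   # stand-in for real params
    toy_prompt_ids = jnp.zeros((1, 77), dtype=jnp.int32)  # stand-in for prepare_inputs output
    toy_rng = jax.random.PRNGKey(0)

    toy_params, toy_prompt_ids, toy_rng = colocate(device, toy_params, toy_prompt_ids, toy_rng)

    # Every leaf should now report exactly one device: the one we asked for.
    for leaf in jax.tree_util.tree_leaves((toy_params, toy_prompt_ids, toy_rng)):
        assert leaf.devices() == {device}
    print("all inputs on", device)
```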

Things I’ve tried that don’t seem to help:

  • wrapping the calls in with jax.default_device(jax_device):
  • wrapping prompt_inputs in jnp.asarray
  • using revision="flax" instead of revision="bf16". I'm not sure whether my CUDA device really supports bfloat16, but if I try to load the full-precision Flax model, I run out of memory. (And I haven't found any flax-fp16 model to load.)
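Since no fp16 Flax checkpoint turned up, one untested option is to load the full-precision weights (which land on the CPU anyway) and cast their floating-point leaves to float16 before the device_put, roughly halving the footprint that reaches the GPU. The sketch below assumes params is an ordinary pytree of jax.Array leaves; cast_floating is a hypothetical helper, not a diffusers API, and whether fp16 weights give acceptable images for this model is not verified here.

```python
import jax
import jax.numpy as jnp


def cast_floating(tree, dtype=jnp.float16):
    """Cast floating-point jax.Array leaves to `dtype`.

    Integer arrays (e.g. step counters) and non-array leaves pass through
    unchanged. For arrays committed to the CPU, the cast runs on the CPU too.
    """
    def cast(x):
        if isinstance(x, jax.Array) and jnp.issubdtype(x.dtype, jnp.floating):
            return x.astype(dtype)
        return x

    return jax.tree_util.tree_map(cast, tree)


if __name__ == "__main__":
    toy_params = {"kernel": jnp.ones((4, 4)), "step": jnp.array(0, dtype=jnp.int32)}
    half = cast_floating(toy_params)
    print({k: v.dtype for k, v in half.items()})
```

Usage, in terms of the snippet above: params = jax.device_put(cast_floating(params), jax_device).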

If I leave out the device_put entirely, it does run on the CPU, but that is super slow and not what I wanted to measure. (Yet it also uses all my VRAM? I'm confused about that.)

Why is it allocating some DeviceArrays on CPU? How do I find out which ones?
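One way to answer "which ones" for the arrays you hold (params, prompt_inputs, the PRNG key) is to walk the pytree and report the platform of every array leaf. A minimal sketch, assuming JAX 0.4.6 or later for jax.tree_util.tree_flatten_with_path and jax.Array.devices(); leaves_by_platform and leaves_on are hypothetical helper names. It cannot see arrays created inside the pipeline call itself.

```python
import jax
import jax.numpy as jnp


def leaves_by_platform(tree):
    """Map each jax.Array leaf's path (e.g. "['unet']['kernel']") to its platforms."""
    placement = {}
    path_leaves, _ = jax.tree_util.tree_flatten_with_path(tree)
    for path, leaf in path_leaves:
        if isinstance(leaf, jax.Array):  # skip Python scalars, strings, etc.
            platforms = sorted({d.platform for d in leaf.devices()})
            placement[jax.tree_util.keystr(path)] = platforms
    return placement


def leaves_on(tree, platform):
    """Return the paths of leaves with at least one shard on `platform` ("cpu", "gpu", ...)."""
    return [path for path, plats in leaves_by_platform(tree).items() if platform in plats]


if __name__ == "__main__":
    cpu = jax.devices("cpu")[0]
    toy_params = {
        "unet": {"kernel": jax.device_put(jnp.ones((2, 2)), cpu)},  # explicitly on CPU
        "vae": {"bias": jnp.zeros(3)},                              # default device
    }
    for path, plats in leaves_by_platform(toy_params).items():
        print(path, plats)
    print("on cpu:", leaves_on(toy_params, "cpu"))
```

With the real pipeline, leaves_on(params, "cpu") after the device_put should come back empty if the move worked.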

Thanks @keturn, I’ll take a look!

Regarding memory consumption: JAX is very aggressive and reserves a huge chunk of GPU memory when it starts, but there are ways to change this strategy; see "GPU memory allocation" in the JAX documentation.
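The JAX "GPU memory allocation" page controls this through environment variables such as XLA_PYTHON_CLIENT_PREALLOCATE and XLA_PYTHON_CLIENT_MEM_FRACTION, which only take effect if they are set before JAX initializes its GPU backend. A minimal sketch, assuming the variables are set in the script itself before the first import of jax (exporting them in the shell works too). JAX_PLATFORMS=cpu, honored by recent JAX versions, keeps JAX off the GPU entirely, which is presumably what a CPU-only comparison run that doesn't touch VRAM would need.

```python
import os

# Must be set before JAX initializes its GPU backend; doing it before the
# first `import jax` is the simplest way to guarantee that.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand instead of up front
# Alternative: keep preallocation but reserve only half of the GPU memory.
# os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = ".50"
# For a CPU-only run that should not touch the GPU at all:
# os.environ["JAX_PLATFORMS"] = "cpu"

import jax  # noqa: E402
import jax.numpy as jnp  # noqa: E402

x = jnp.ones((1024, 1024))
print(jax.devices(), float(x.sum()))
```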

JAX will preallocate 90% of currently-available GPU memory when the first JAX operation is run.

Ahahahaha :joy: Yup, that explains where all the GPU memory went, even though it doesn't seem to be using the GPU for compute. Thank you for that.

Yes, it’s a bit distressing the first time you see it :slight_smile:

I’ll test on GPU later to try to understand what the problem might be.