I was curious how the JAX implementation of Stable Diffusion compares to PyTorch for those of us who don’t have a TPU to play with at home. Here’s what I’m trying to run:
```python
import jax
import jax.numpy as jnp
import diffusers

jax_device = jax.local_devices(backend="gpu")[0]

pipe, params = diffusers.FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="bf16",
    safety_checker=None,
    feature_extractor=NullFeatureExtractor(),  # my no-op stub; it's cranky if we use None here
    dtype=jnp.bfloat16,
)
params = jax.device_put(params, jax_device)

prompt_inputs = pipe.prepare_inputs("An astronaut riding JAX on Mars.")
result = pipe(
    prompt_ids=prompt_inputs,
    num_inference_steps=12,
    params=params,
    prng_seed=jax.random.PRNGKey(0),
)
```
For some reason, `params` initially loads onto the CPU instead of the GPU device, hence the `jax.device_put`. But even with that, I still get errors like this:

```
primitive arguments must be colocated on the same device (C++ jax.jit). Arguments are on devices: gpu:0 and TFRT_CPU_0
```
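As far as I can tell this is JAX’s generic mixed-device complaint. A minimal repro of the same class of error (my sketch, and the exact message varies by JAX version) would be passing one array committed to the GPU and one committed to the CPU into a jitted function:

```python
import jax
import jax.numpy as jnp

gpu = jax.local_devices(backend="gpu")[0]
cpu = jax.local_devices(backend="cpu")[0]

a = jax.device_put(jnp.ones(4), gpu)  # committed to the GPU
b = jax.device_put(jnp.ones(4), cpu)  # committed to the CPU

jax.jit(lambda x, y: x + y)(a, b)  # raises a colocation error like the one above
```

So presumably some leaf of `params` (or something the pipeline builds internally) is still committed to `TFRT_CPU_0`.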
Things I’ve tried that don’t seem to help (the first two are sketched below):

- wrapping everything in `with jax.default_device(jax_device):`
- wrapping `prompt_inputs` in `jnp.asarray`
- using `revision="flax"` instead of `"bf16"`. I’m not sure whether my CUDA device really supports bfloat16, but if I try to load the full-precision Flax model, I run out of memory. (And I haven’t found any `flax-fp16` model to load.)
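For reference, the first two attempts looked roughly like this, and still produced the same error:

```python
with jax.default_device(jax_device):
    prompt_inputs = jnp.asarray(pipe.prepare_inputs("An astronaut riding JAX on Mars."))
    result = pipe(
        prompt_ids=prompt_inputs,
        num_inference_steps=12,
        params=params,
        prng_seed=jax.random.PRNGKey(0),
    )
```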
If I leave out the `jax.device_put` entirely, it does run on the CPU, but that is super slow and not what I wanted to measure. (But it also uses all my vRAM? Confused about that.)
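One guess about the vRAM: I know JAX’s GPU backend preallocates a large chunk of device memory as soon as it initializes, regardless of where the arrays actually end up, so maybe that’s all I’m seeing. Something like this (set before anything touches JAX) should confirm or rule that out:

```python
import os

# Must be set before JAX initializes its backends.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax  # imported after setting the env var on purpose
```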
Why is it allocating some `DeviceArray`s on the CPU? How do I find out which ones?
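My best guess for enumerating them, assuming each leaf exposes `.device()` (older JAX; newer `jax.Array`s have `.devices()` instead), is a `tree_map` over `params`:

```python
import jax

# Map every leaf of the params pytree to the name of the device it lives on.
# The result is a nested dict mirroring params, so CPU stragglers stand out.
placements = jax.tree_util.tree_map(lambda leaf: str(leaf.device()), params)
print(placements)
```

But even if that shows everything on `gpu:0` after the `device_put`, I don’t know how to see which array the pipeline itself is creating on `TFRT_CPU_0`.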