prompt = "masterpiece, best quality, 1girl, at dusk"
neg_prompt = "(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2)"

# One sample per device: every per-sample input below is given a leading
# device axis with shard(), and the params are replicated to every device.
num_samples = jax.device_count()

# NOTE(review): confirm that from_pretrained actually consumes
# `load_attn_procs`; if the Flax loader ignores unknown keyword arguments,
# the LoRA weights are silently never applied.
pipeline, params = FlaxStableDiffusionImg2ImgPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    revision="flax",
    dtype=jnp.bfloat16,
    safety_checker=None,
    load_attn_procs="sayakpaul/civitai-light-shadow-lora",
)

# One PRNG key per device.
rng = create_key(0)
rng = jax.random.split(rng, num_samples)

prompt_ids, processed_image = pipeline.prepare_inputs(
    prompt=[prompt] * num_samples,
    image=[init_img] * num_samples,
)
# prepare_inputs only accepts (prompt, image), so the negative prompt is
# tokenized through the same call. Only its token ids are needed: discard the
# duplicate preprocessed image instead of overwriting processed_image with it.
neg_prompt_ids, _ = pipeline.prepare_inputs(
    prompt=[neg_prompt] * num_samples,
    image=[init_img] * num_samples,
)

p_params = replicate(params)
prompt_ids = shard(prompt_ids)
processed_image = shard(processed_image)
# Fix: neg_prompt_ids is a per-sample input exactly like prompt_ids, so it
# must be sharded the same way. Left unsharded, the jitted (pmapped) call
# presumably hands each device an array without the batch axis that the
# sharded prompt_ids carry, which is the likely cause of the silent crash.
neg_prompt_ids = shard(neg_prompt_ids)

# NOTE(review): height/width presumably must match the size of processed_image
# (derived from init_img); confirm init_img is 512 wide x 768 high.
output = pipeline(
    prompt_ids=prompt_ids,
    image=processed_image,
    neg_prompt_ids=neg_prompt_ids,
    params=p_params,
    prng_seed=rng,
    strength=0.6,
    num_inference_steps=50,
    jit=True,
    height=768,
    width=512,
).images
I generated `neg_prompt_ids` the same way as `prompt_ids`, because `prepare_inputs` accepts only two arguments (`prompt` and `image`). However, the code above crashes without printing any error message.