FluxPipeline error while loading Flux.1 dev

When I try to run the FLUX.1-dev model locally, as shown on the model page, I get this error:

RuntimeError: Failed to import diffusers.pipelines.flux.pipeline_flux because of the following error (look up to see its traceback):
'Config' object has no attribute 'define_bool_state'

My code is exactly the same as the example:

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload() #save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    height=1024,
    width=1024,
    guidance_scale=3.5,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0)
).images[0]
image.save("flux-dev.png")

I checked, and diffusers and jax are already up to date.
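To double-check exactly which versions are installed (and whether JAX or Flax is present at all, since the traceback points at a JAX `Config` object), a small stdlib-only sketch like this can help. The package list is an assumption about what is relevant here.

```python
from importlib.metadata import PackageNotFoundError, version

# Packages that plausibly matter for this error (assumed list)
PACKAGES = ["torch", "diffusers", "transformers", "jax", "jaxlib", "flax"]


def installed_versions(packages=PACKAGES):
    """Return {package: version string, or None if not installed}."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name:12} {ver or 'not installed'}")
```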

Since dev is a gated model, working code needs a step that authenticates with a Hugging Face access token. schnell doesn’t need one, so try it with schnell first.
The example code on model pages is often not quite accurate in general.

black-forest-labs/FLUX.1-schnell
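If you do want FLUX.1-dev, you need to accept its license on the model page and pass a token. A minimal sketch, assuming the token is stored in the HF_TOKEN environment variable (logging in once with `huggingface-cli login` is an alternative); `load_flux_dev` is a hypothetical helper name, not part of Diffusers.

```python
import os


def load_flux_dev():
    """Hypothetical helper: load the gated FLUX.1-dev model with an HF token."""
    # Assumption: the access token is provided via the HF_TOKEN environment variable.
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise RuntimeError(
            "FLUX.1-dev is gated: accept the license on its model page, "
            "then set HF_TOKEN to a Hugging Face access token."
        )
    # Heavy imports happen only after the token check.
    import torch
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        torch_dtype=torch.bfloat16,
        token=token,
    )
    pipe.enable_model_cpu_offload()  # trade speed for lower VRAM use
    return pipe
```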

I tried the sample code and I’m getting the same error:

RuntimeError: Failed to import diffusers.pipelines.flux.pipeline_flux because of the following error (look up to see its traceback):
'Config' object has no attribute 'define_bool_state'

Normally, if you have enough VRAM, the following code should work, but after searching, this looks like a trickier error than I thought.
Diffusers might not be the only cause.

I’m starting to suspect a mismatched installation of the CUDA libraries and the CUDA-enabled torch build, or some other library misbehaving.
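To test the CUDA/torch theory, a quick sketch that reports how the installed torch build sees CUDA. A CPU-only build reports `built_with_cuda` as None, and a driver/build mismatch often shows up as `cuda_available` being False even though a GPU is present.

```python
def torch_cuda_report():
    """Summarize how the installed torch build sees CUDA."""
    try:
        import torch
    except ImportError:
        return {"torch_installed": False}
    report = {
        "torch_installed": True,
        "torch_version": torch.__version__,
        "built_with_cuda": torch.version.cuda,  # None for CPU-only builds
        "cuda_available": torch.cuda.is_available(),
    }
    if report["cuda_available"]:
        report["device_name"] = torch.cuda.get_device_name(0)
    return report


if __name__ == "__main__":
    for key, value in torch_cuda_report().items():
        print(f"{key}: {value}")
```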

For example, do SD1.5 and SDXL models work with the same code?
If you just replace the model name in the code below, it should generally work.

pip install -U diffusers

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload() #save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
).images[0]
image.save("flux-schnell.png")

Indeed, it seems to be an issue with Flax.

If so, there are three possibilities:

  • The Flux model is too large for the available RAM or VRAM.
  • There is an environment-dependent bug in Diffusers’ FluxPipeline.
  • The installed version of Diffusers is outdated (uninstalling and then reinstalling Diffusers may fix this).

If VRAM is the bottleneck, the following NF4 (4-bit quantized) model may solve the problem.
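As a back-of-the-envelope check, assuming the FLUX.1 transformer has roughly 12 billion parameters, its weights alone need about 2 bytes per parameter in bfloat16 and about half a byte in 4-bit NF4. This sketch ignores the text encoders, VAE, activations, and quantization overhead, so real usage is higher.

```python
def weight_memory_gb(num_params, bytes_per_param):
    """Approximate weight memory in GB (1 GB = 1e9 bytes), weights only."""
    return num_params * bytes_per_param / 1e9


FLUX_TRANSFORMER_PARAMS = 12e9  # assumption: ~12B-parameter transformer

for label, bpp in [("bfloat16", 2), ("NF4 (4-bit)", 0.5)]:
    print(f"{label:12} ~{weight_memory_gb(FLUX_TRANSFORMER_PARAMS, bpp):.0f} GB")
```

That works out to roughly 24 GB versus 6 GB for the transformer weights, which is why a 4-bit variant can fit where the bfloat16 one cannot.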

Hi,

Could you remove JAX from your environment (pip uninstall jax)? Alternatively, you could try upgrading it (pip install --upgrade jax).
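Before uninstalling, you can confirm that the installed JAX is the one missing the attribute named in the error. A minimal sketch; it returns None when JAX is not installed at all.

```python
def jax_has_define_bool_state():
    """Return True/False if JAX is installed, None if it is not."""
    try:
        import jax
    except ImportError:
        return None
    # The reported error says jax.config lacks this attribute; any package
    # that still calls it (presumably an older Flax here) fails on import.
    return hasattr(jax.config, "define_bool_state")


if __name__ == "__main__":
    print("jax.config.define_bool_state present:", jax_has_define_bool_state())
```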

The code snippet only requires PyTorch and Diffusers to be installed.
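After removing or upgrading JAX, a quick smoke test is to retry the import that originally failed, without loading any model weights. A sketch:

```python
def flux_pipeline_importable():
    """Retry the import that failed in the original report."""
    try:
        from diffusers import FluxPipeline  # noqa: F401
    except Exception as exc:
        # Catch broadly: this may be an ImportError or the RuntimeError
        # quoted above, depending on the environment.
        print(f"Import failed: {type(exc).__name__}: {exc}")
        return False
    return True


if __name__ == "__main__":
    print("FluxPipeline importable:", flux_pipeline_importable())
```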