gdlm
September 3, 2024, 8:59pm
1
Trying to run the FLUX.1-dev model locally as shown on the model page, I get the error
RuntimeError: Failed to import diffusers.pipelines.flux.pipeline_flux because of the following error (look up to see its traceback):
'Config' object has no attribute 'define_bool_state'
The code is exactly the same as the example
import torch
from diffusers import FluxPipeline
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power
prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
    height=1024,
    width=1024,
    guidance_scale=3.5,
    num_inference_steps=50,
    max_sequence_length=512,
    generator=torch.Generator("cpu").manual_seed(0)  # "gpu" is not a valid torch device string
).images[0]
image.save("flux-dev.png")
I checked that diffusers and jax are already up to date.
Since dev is a gated model, the actual code should have a part that deals with tokens. schnell doesn't need it, so try it with schnell first.
There's quite a bit of lying in the Examples in general.
black-forest-labs/FLUX.1-schnell
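The token handling mentioned above could look roughly like the sketch below. It assumes the access token is stored in an environment variable named `HF_TOKEN` and that `huggingface_hub` is installed. The helper names `get_hf_token` and `login_if_token_available` are made up for illustration and are not part of any library.

```python
import os


def get_hf_token(env_var="HF_TOKEN"):
    """Return the token stored in the given environment variable, or None if unset or empty."""
    token = os.environ.get(env_var, "").strip()
    return token or None


def login_if_token_available():
    """Log in to the Hugging Face Hub when a token is found; return True if a login was attempted."""
    token = get_hf_token()
    if token is None:
        return False
    # Imported lazily so the helper above also works without huggingface_hub installed.
    from huggingface_hub import login
    login(token=token)
    return True
```

Calling `login_if_token_available()` before `from_pretrained` would be one way to authenticate for a gated repository. Note that this only addresses access to the model, not the import error reported in this thread.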
gdlm
September 6, 2024, 4:32pm
3
I tried with the sample code and I'm getting the same error
RuntimeError: Failed to import diffusers.pipelines.flux.pipeline_flux because of the following error (look up to see its traceback):
'Config' object has no attribute 'define_bool_state'
opened 01:20PM - 05 Jul 23 UTC
Priority: P2 - no schedule
### System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04…): tpu-vm-base
- Flax, jax, jaxlib versions (obtain with `pip show flax jax jaxlib`): flax: 0.6.7, jax: 0.4.13, jaxlib: 0.4.13
- Python version: 3.8
- GPU/TPU model and memory: tpu v3-8
- CUDA version (if applicable): N/A
### Problem you have encountered:
Hello, I am seeing
```
AttributeError: module 'jax.config' has no attribute 'define_bool_state'
```
when running the `big_vision` library on `tpu-vm-base`.
I saw a similar issue in [another library](https://github.com/google-research/scenic/issues/816) where the issue seems to be resolved by fixing `jax` version to `0.4.9`, but when I attempted that it did not work.
I also tried fixing the versions of all packages in the `requirements.txt` of `big_vision`, i.e.
```
absl-py==1.4.0
clu==0.0.8
einops==0.6.0
flax==0.6.7
git+https://github.com/google/flaxformer
git+https://github.com/deepmind/optax.git
git+https://github.com/akolesnikoff/panopticapi.git@mute
overrides==7.3.1
tensorflow==2.12.0
tfds-nightly==4.8.3.dev202303250044
tensorflow-addons==0.19.0
tensorflow-text==2.12.0
tensorflow-gan==2.1.0
```
at the same time when fixing `jax` to `0.4.9`, but that did not work either.
I had to use a full `requirements.txt` obtained from running `pip freeze` in a local venv created on 2023-03-26 to get the library running on TPU again.
### What you expected to happen:
I was able to run `big_vision` on `tpu-vm-base` on a `v3-8` TPU node without fixing any package versions as late as 2023-05-24.
### Logs, error messages, etc:
```
Installing collected packages: libtpu-nightly, zipp, numpy, scipy, opt-einsum, ml-dtypes, importlib-metadata, jaxlib, jax
Successfully installed importlib-metadata-6.7.0 jax-0.4.13 jaxlib-0.4.13 libtpu-nightly-0.1.dev20230622 ml-dtypes-0.2.0 numpy-1.24.4 opt-einsum-3.3.0 scipy-1.10.1 zipp-3.15.0
Collecting git+https://github.com/google/flaxformer (from -r big_vision/requirements.txt (line 5))
Cloning https://github.com/google/flaxformer to /tmp/pip-req-build-925ai1ze
Running command git clone --filter=blob:none --quiet https://github.com/google/flaxformer /tmp/pip-req-build-925ai1ze
Resolved https://github.com/google/flaxformer to commit 9adaa4467cf17703949b9f537c3566b99de1b416
Installing build dependencies: started
Installing build dependencies: finished with status 'done'
Getting requirements to build wheel: started
Getting requirements to build wheel: finished with status 'done'
Preparing metadata (pyproject.toml): started
Preparing metadata (pyproject.toml): finished with status 'done'
[omitted]
Collecting flax==0.6.7 (from -r big_vision/requirements.txt (line 4))
Downloading flax-0.6.7-py3-none-any.whl (214 kB)
ββββββββββββββββββββββββββββββββββββββ 214.2/214.2 kB 28.6 MB/s eta 0:00:00
[omitted]
Building wheels for collected packages: flaxformer, optax, panopticapi, ml-collections, promise
Building wheel for flaxformer (pyproject.toml): started
Building wheel for flaxformer (pyproject.toml): finished with status 'done'
Created wheel for flaxformer: filename=flaxformer-0.8.1-py3-none-any.whl size=321948 sha256=df38d4209289e8a71a245b56f95490ec0ce9c2bbfaa164fd00d1b7e2f80b5869
[omitted]
Successfully installed MarkupSafe-2.1.3 Pillow-10.0.0 PyYAML-6.0 absl-py-1.4.0 aqtp-0.1.1 array-record-0.4.0 astunparse-1.6.3 cached_property-1.5.2 cachetools-5.3.1 certifi-2023.5.7 charset-normalizer-3.1.0 chex-0.1.7 click-8.1.3 cloudpickle-2.2.1 clu-0.0.8 contextlib2-21.6.0 dacite-1.8.1 decorator-5.1.1 dm-tree-0.1.8 einops-0.6.0 etils-1.3.0 flatbuffers-23.5.26 flax-0.6.7 flaxformer-0.8.1 gast-0.4.0 google-auth-2.21.0 google-auth-oauthlib-1.0.0 google-pasta-0.2.0 googleapis-common-protos-1.59.1 grpcio-1.56.0 h5py-3.9.0 idna-3.4 importlib-resources-5.12.0 keras-2.12.0 libclang-16.0.0 markdown-3.4.3 markdown-it-py-3.0.0 mdurl-0.1.2 ml-collections-0.1.1 msgpack-1.0.5 nest_asyncio-1.5.6 numpy-1.23.5 oauthlib-3.2.2 optax-0.1.5 orbax-0.1.7 overrides-7.3.1 packaging-23.1 panopticapi-0.1 promise-2.3 protobuf-4.23.3 psutil-5.9.5 pyasn1-0.5.0 pyasn1-modules-0.3.0 pygments-2.15.1 requests-2.31.0 requests-oauthlib-1.3.1 rich-13.4.2 rsa-4.9 six-1.16.0 tensorboard-2.12.3 tensorboard-data-server-0.7.1 tensorflow-2.12.0 tensorflow-addons-0.19.0 tensorflow-datasets-4.9.2 tensorflow-estimator-2.12.0 tensorflow-gan-2.1.0 tensorflow-hub-0.13.0 tensorflow-io-gcs-filesystem-0.32.0 tensorflow-metadata-1.13.1 tensorflow-probability-0.20.1 tensorflow-text-2.12.0 tensorstore-0.1.40 termcolor-2.3.0 tfds-nightly-4.8.3.dev202303250044 toml-0.10.2 toolz-0.12.0 tqdm-4.65.0 typeguard-4.0.0 typing-extensions-4.7.1 urllib3-1.26.16 werkzeug-2.3.6 wheel-0.40.0 wrapt-1.14.1
Traceback (most recent call last):
File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/yuyang/big_vision/train.py", line 28, in <module>
import big_vision.evaluators.common as eval_common
File "/home/yuyang/big_vision/evaluators/common.py", line 22, in <module>
import flax
File "/home/yuyang/bv_venv/lib/python3.8/site-packages/flax/__init__.py", line 18, in <module>
from .configurations import (
File "/home/yuyang/bv_venv/lib/python3.8/site-packages/flax/configurations.py", line 93, in <module>
flax_filter_frames = define_bool_state(
File "/home/yuyang/bv_venv/lib/python3.8/site-packages/flax/configurations.py", line 42, in define_bool_state
return jax_config.define_bool_state('flax_' + name, default, help)
AttributeError: module 'jax.config' has no attribute 'define_bool_state'
```
### Steps to reproduce:
1. Check out `big_vision` locally
```
git clone git@github.com:google-research/big_vision.git
```
2. Create a TPU node
```
gcloud compute tpus tpu-vm create $VM_NAME --zone=$ZONE --accelerator-type=v3-8 --version=tpu-vm-base
```
3. Upload `big_vision` to the TPU and start training
```
gcloud compute tpus tpu-vm scp --recurse big_vision/big_vision $VM_NAME: --zone=$ZONE --worker=all
gcloud compute tpus tpu-vm ssh $VM_NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/vit_s16_i1k.py --workdir gs://$BUCKET_NAME/workdirs/`date '+%m-%d_%H%M'`"
```
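The traceback in the issue above shows `flax/configurations.py` calling `jax_config.define_bool_state`. A quick way to see whether the installed jax still exposes that attribute is sketched below. The function name `jax_has_define_bool_state` is hypothetical, and the probe simply mirrors the attribute lookup from the traceback rather than any official compatibility check.

```python
import importlib


def jax_has_define_bool_state():
    """Return True/False depending on whether jax.config exposes define_bool_state.

    Returns None when jax is missing or fails to import.
    """
    try:
        jax = importlib.import_module("jax")
    except Exception:
        return None
    return hasattr(jax.config, "define_bool_state")


if __name__ == "__main__":
    print("jax.config.define_bool_state present:", jax_has_define_bool_state())
```

If this prints `False` while an older flax is installed, the combination matches the failure shown in the traceback.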
Normally, if you have enough VRAM, the following code should work, but I've searched and it may be a trickier error than I thought.
It means that Diffusers might not be the only cause.
I'm starting to suspect a mismatch between the installed CUDA libraries and the CUDA build of torch, or some other library that is doing something wrong.
For example, do SD1.5 and SDXL models work with the same code?
If you just replace the model name part of the code below, it should work in general.
pip install -U diffusers
import torch
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload() #save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power
prompt = "A cat holding a sign that says hello world"
image = pipe(
prompt,
).images[0]
image.save("flux-schnell.png")
gdlm
September 8, 2024, 3:47pm
5
Indeed it seems to be an issue with Flax.
If so, there are three possibilities.
- The Flux model is too large and there is not enough RAM or VRAM.
- There is an environment-dependent bug in Diffusers' FluxPipeline.
- The installed version of Diffusers is outdated (uninstalling and then reinstalling Diffusers may fix this).
If it is a lack of VRAM, the following NF4 model may solve the problem.
nielsr
September 9, 2024, 7:58am
7
Hi,
Could you remove JAX from your environment with `pip uninstall jax`? Alternatively, you could try upgrading JAX with `pip install --upgrade jax`.
The code snippet only requires PyTorch and Diffusers to be installed.
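Before uninstalling or upgrading anything, it may help to see which of the relevant packages are actually present in the environment. The following is a minimal sketch using only the standard library. The package list and the helper name `installed_versions` are assumptions chosen for this thread.

```python
from importlib import metadata


def installed_versions(packages=("diffusers", "torch", "jax", "jaxlib", "flax")):
    """Map each package name to its installed version, or None if it is not installed."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

If `jax` or `flax` shows up here even though the snippet only needs PyTorch and Diffusers, that would be consistent with the suggestion to remove or upgrade it.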