opened 01:20PM - 05 Jul 23 UTC
Priority: P2 - no schedule
### System information
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04…): tpu-vm-base
- Flax, jax, jaxlib versions (obtain with `pip show flax jax jaxlib`): flax: 0.6.7, jax: 0.4.13, jaxlib: 0.4.13
- Python version: 3.8
- GPU/TPU model and memory: tpu v3-8
- CUDA version (if applicable): N/A
### Problem you have encountered:
Hello, I am seeing
```
AttributeError: module 'jax.config' has no attribute 'define_bool_state'
```
when running the `big_vision` library on `tpu-vm-base`.
I saw a similar issue in [another library](https://github.com/google-research/scenic/issues/816), which appears to have been resolved by pinning `jax` to `0.4.9`, but that did not work when I tried it.
I also tried pinning the versions of all packages in the `requirements.txt` of `big_vision`, i.e.
```
absl-py==1.4.0
clu==0.0.8
einops==0.6.0
flax==0.6.7
git+https://github.com/google/flaxformer
git+https://github.com/deepmind/optax.git
git+https://github.com/akolesnikoff/panopticapi.git@mute
overrides==7.3.1
tensorflow==2.12.0
tfds-nightly==4.8.3.dev202303250044
tensorflow-addons==0.19.0
tensorflow-text==2.12.0
tensorflow-gan==2.1.0
```
while also pinning `jax` to `0.4.9`, but that did not work either.
I had to use a full `requirements.txt` obtained from running `pip freeze` in a local venv created on 2023-03-26 to get the library running on TPU again.
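For anyone who wants to capture a known-good environment the same way, here is a minimal sketch of a `pip freeze`-like listing using only the standard library. The function name `frozen_requirements` is hypothetical, and unlike `pip freeze` it emits plain `name==version` lines even for packages installed from a git URL, so treat the output as a starting point rather than an exact replacement.

```python
from importlib import metadata


def frozen_requirements():
    """Return sorted 'name==version' lines for every installed distribution.

    Hypothetical helper: a rough stand-in for `pip freeze`, not its exact output.
    """
    pins = set()
    for dist in metadata.distributions():
        name = dist.metadata.get("Name")
        if name:  # skip installs whose metadata lacks a name
            pins.add(f"{name}=={dist.version}")
    return sorted(pins, key=str.lower)


if __name__ == "__main__":
    # Redirect this output to a file to keep a snapshot of a working venv.
    print("\n".join(frozen_requirements()))
```

Running this inside the working venv and saving the output gives a file that can be passed to `pip install -r` on the TPU VM.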
### What you expected to happen:
I expected `big_vision` to run on `tpu-vm-base` on a `v3-8` TPU node without pinning any package versions, as it did as recently as 2023-05-24.
### Logs, error messages, etc:
```
Installing collected packages: libtpu-nightly, zipp, numpy, scipy, opt-einsum, ml-dtypes, importlib-metadata, jaxlib, jax
Successfully installed importlib-metadata-6.7.0 jax-0.4.13 jaxlib-0.4.13 libtpu-nightly-0.1.dev20230622 ml-dtypes-0.2.0 numpy-1.24.4 opt-einsum-3.3.0 scipy-1.10.1 zipp-3.15.0
Collecting git+https://github.com/google/flaxformer (from -r big_vision/requirements.txt (line 5))
Cloning https://github.com/google/flaxformer to /tmp/pip-req-build-925ai1ze
Running command git clone --filter=blob:none --quiet https://github.com/google/flaxformer /tmp/pip-req-build-925ai1ze
Resolved https://github.com/google/flaxformer to commit 9adaa4467cf17703949b9f537c3566b99de1b416
Installing build dependencies: started
Installing build dependencies: finished with status 'done'
Getting requirements to build wheel: started
Getting requirements to build wheel: finished with status 'done'
Preparing metadata (pyproject.toml): started
Preparing metadata (pyproject.toml): finished with status 'done'
[omitted]
Collecting flax==0.6.7 (from -r big_vision/requirements.txt (line 4))
Downloading flax-0.6.7-py3-none-any.whl (214 kB)
ββββββββββββββββββββββββββββββββββββββ 214.2/214.2 kB 28.6 MB/s eta 0:00:00
[omitted]
Building wheels for collected packages: flaxformer, optax, panopticapi, ml-collections, promise
Building wheel for flaxformer (pyproject.toml): started
Building wheel for flaxformer (pyproject.toml): finished with status 'done'
Created wheel for flaxformer: filename=flaxformer-0.8.1-py3-none-any.whl size=321948 sha256=df38d4209289e8a71a245b56f95490ec0ce9c2bbfaa164fd00d1b7e2f80b5869
[omitted]
Successfully installed MarkupSafe-2.1.3 Pillow-10.0.0 PyYAML-6.0 absl-py-1.4.0 aqtp-0.1.1 array-record-0.4.0 astunparse-1.6.3 cached_property-1.5.2 cachetools-5.3.1 certifi-2023.5.7 charset-normalizer-3.1.0 chex-0.1.7 click-8.1.3 cloudpickle-2.2.1 clu-0.0.8 contextlib2-21.6.0 dacite-1.8.1 decorator-5.1.1 dm-tree-0.1.8 einops-0.6.0 etils-1.3.0 flatbuffers-23.5.26 flax-0.6.7 flaxformer-0.8.1 gast-0.4.0 google-auth-2.21.0 google-auth-oauthlib-1.0.0 google-pasta-0.2.0 googleapis-common-protos-1.59.1 grpcio-1.56.0 h5py-3.9.0 idna-3.4 importlib-resources-5.12.0 keras-2.12.0 libclang-16.0.0 markdown-3.4.3 markdown-it-py-3.0.0 mdurl-0.1.2 ml-collections-0.1.1 msgpack-1.0.5 nest_asyncio-1.5.6 numpy-1.23.5 oauthlib-3.2.2 optax-0.1.5 orbax-0.1.7 overrides-7.3.1 packaging-23.1 panopticapi-0.1 promise-2.3 protobuf-4.23.3 psutil-5.9.5 pyasn1-0.5.0 pyasn1-modules-0.3.0 pygments-2.15.1 requests-2.31.0 requests-oauthlib-1.3.1 rich-13.4.2 rsa-4.9 six-1.16.0 tensorboard-2.12.3 tensorboard-data-server-0.7.1 tensorflow-2.12.0 tensorflow-addons-0.19.0 tensorflow-datasets-4.9.2 tensorflow-estimator-2.12.0 tensorflow-gan-2.1.0 tensorflow-hub-0.13.0 tensorflow-io-gcs-filesystem-0.32.0 tensorflow-metadata-1.13.1 tensorflow-probability-0.20.1 tensorflow-text-2.12.0 tensorstore-0.1.40 termcolor-2.3.0 tfds-nightly-4.8.3.dev202303250044 toml-0.10.2 toolz-0.12.0 tqdm-4.65.0 typeguard-4.0.0 typing-extensions-4.7.1 urllib3-1.26.16 werkzeug-2.3.6 wheel-0.40.0 wrapt-1.14.1
Traceback (most recent call last):
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/home/yuyang/big_vision/train.py", line 28, in <module>
    import big_vision.evaluators.common as eval_common
  File "/home/yuyang/big_vision/evaluators/common.py", line 22, in <module>
    import flax
  File "/home/yuyang/bv_venv/lib/python3.8/site-packages/flax/__init__.py", line 18, in <module>
    from .configurations import (
  File "/home/yuyang/bv_venv/lib/python3.8/site-packages/flax/configurations.py", line 93, in <module>
    flax_filter_frames = define_bool_state(
  File "/home/yuyang/bv_venv/lib/python3.8/site-packages/flax/configurations.py", line 42, in define_bool_state
    return jax_config.define_bool_state('flax_' + name, default, help)
AttributeError: module 'jax.config' has no attribute 'define_bool_state'
```
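The traceback shows `flax` calling `define_bool_state` on `jax.config` at import time, so a quick way to check a given environment before launching training is to probe for that attribute directly. The sketch below is an assumption-laden diagnostic, not part of `big_vision` or `flax`; `probe_attribute` is a hypothetical helper name.

```python
import importlib


def probe_attribute(module_name, attr_path):
    """Report whether `module_name` exposes the dotted `attr_path`.

    Hypothetical helper. Returns None if the module cannot be imported,
    otherwise True or False.
    """
    try:
        obj = importlib.import_module(module_name)
    except ImportError:
        return None
    for part in attr_path.split("."):
        if not hasattr(obj, part):
            return False
        obj = getattr(obj, part)
    return True


if __name__ == "__main__":
    # The attribute that flax 0.6.7 looks up in the traceback above.
    result = probe_attribute("jax", "config.define_bool_state")
    print("jax.config.define_bool_state present:", result)
```

If this prints `False`, the installed `jax` and `flax` versions are mismatched in the way the error message describes.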
### Steps to reproduce:
1. Check out `big_vision` locally
```
git clone git@github.com:google-research/big_vision.git
```
2. Create a TPU node
```
gcloud compute tpus tpu-vm create $VM_NAME --zone=$ZONE --accelerator-type=v3-8 --version=tpu-vm-base
```
3. Upload `big_vision` to the TPU and start training
```
gcloud compute tpus tpu-vm scp --recurse big_vision/big_vision $VM_NAME: --zone=$ZONE --worker=all
gcloud compute tpus tpu-vm ssh $VM_NAME --zone=$ZONE --worker=all --command "TFDS_DATA_DIR=gs://$BUCKET_NAME/tensorflow_datasets bash big_vision/run_tpu.sh big_vision.train --config big_vision/configs/vit_s16_i1k.py --workdir gs://$BUCKET_NAME/workdirs/`date '+%m-%d_%H%M'`"
```
Normally, if you have enough VRAM, the following code should work, but after some searching this looks like a trickier error than I first thought.
That suggests Diffusers may not be the only cause.
I'm starting to suspect a mismatch between the installed CUDA library and the CUDA build of torch, or some other library that is misbehaving.
For example, do SD1.5 and SDXL models work with the same code?
If you just replace the model name in the code below, it should generally work.
```
pip install -U diffusers
```
```
import torch
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power
prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt,
).images[0]
image.save("flux-schnell.png")
```
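To narrow down whether the torch and CUDA installation is the culprit, a small report like the following may help. It is only a sketch: `torch_cuda_report` is a hypothetical helper name, and it falls back gracefully when torch is not installed at all.

```python
def torch_cuda_report():
    """Collect basic facts about the installed torch build and its CUDA support.

    Hypothetical helper for diagnosis only.
    """
    try:
        import torch
    except ImportError:
        return {"torch_installed": False}
    return {
        "torch_installed": True,
        "torch_version": torch.__version__,
        "built_for_cuda": torch.version.cuda,  # None for builds without CUDA
        "cuda_available": torch.cuda.is_available(),
    }


if __name__ == "__main__":
    for key, value in torch_cuda_report().items():
        print(f"{key}: {value}")
```

A build that reports a CUDA version but `cuda_available: False` would be consistent with the mismatch suspected above, although other causes (drivers, container setup) are possible too.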