I'm running inference on a custom VLM-derived model. Inference works fine when I use the weights in their native bfloat16 precision. However, when I define a bitsandbytes config, I get an error that I suspect is due to a conflict between bitsandbytes and Accelerate: both appear to be trying to set the compute device, which produces the following stack trace.
Traceback (most recent call last):
File "/home/tyr/RobotAI/openvla/scripts/extern/verify_prismatic.py", line 147, in <module>
verify_prismatic()
File "/home/tyr/miniforge3/envs/openvla/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/home/tyr/RobotAI/openvla/scripts/extern/verify_prismatic.py", line 97, in verify_prismatic
vlm = AutoModelForVision2Seq.from_pretrained(
File "/home/tyr/miniforge3/envs/openvla/lib/python3.10/site-packages/transformers/models/auto/auto_factory.py", line 563, in from_pretrained
return model_class.from_pretrained(
File "/home/tyr/miniforge3/envs/openvla/lib/python3.10/site-packages/transformers/modeling_utils.py", line 3735, in from_pretrained
dispatch_model(model, **device_map_kwargs)
File "/home/tyr/miniforge3/envs/openvla/lib/python3.10/site-packages/accelerate/big_modeling.py", line 499, in dispatch_model
model.to(device)
File "/home/tyr/miniforge3/envs/openvla/lib/python3.10/site-packages/transformers/modeling_utils.py", line 2670, in to
raise ValueError(
ValueError: `.to` is not supported for `4-bit` or `8-bit` bitsandbytes models. Please use the model as it is, since the model has already been set to the correct devices and casted to the correct `dtype`.
This is the code that generated the above stack trace:
import torch
from transformers import AutoModelForVision2Seq, BitsAndBytesConfig

# 4-bit quantized load that triggers the ValueError above
vlm = AutoModelForVision2Seq.from_pretrained(
    MODEL_PATH,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)
I've checked that the model is not being moved with a .to() call anywhere in my own code. I've also tried passing device_map=None and setting torch_dtype="auto", but neither resolves the issue.
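Concretely, the two variants I tried look roughly like this (a sketch, not the exact code from my script; MODEL_PATH is the same checkpoint path as above):

import torch
from transformers import AutoModelForVision2Seq, BitsAndBytesConfig

# Variant 1: explicitly pass device_map=None so no device map is used --
# still raises the same ValueError for me.
vlm = AutoModelForVision2Seq.from_pretrained(
    MODEL_PATH,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    device_map=None,
)

# Variant 2: let transformers infer the dtype from the checkpoint config
# instead of forcing float16 -- same error.
vlm = AutoModelForVision2Seq.from_pretrained(
    MODEL_PATH,
    attn_implementation="flash_attention_2",
    torch_dtype="auto",
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)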
Has anyone encountered this error before, or does anyone have suggestions about what might be going wrong? Thanks!