How can I set the `max_memory` parameter when loading a quantized model with `from_pretrained`?

What do I want?

  • I want to load a model onto specific devices and run inference with it, using `device_map='auto'` to take advantage of the Accelerate library’s efficient memory management.

What problem did I encounter?

  • Loading a quantized model with the `max_memory` parameter fails with a `NotImplementedError`.

What have I tried?

from transformers import Qwen2_5_VLForConditionalGeneration, AutoConfig
import torch
import psutil

from typing import Union, List


def build_model(model_path: str,
                gpu_ids: Union[int, List[int]] = 0,
                **kwargs
                ):
    
    if isinstance(gpu_ids, int):
        gpu_ids = [gpu_ids]
    
    max_memory = {}
    for gpu_id in gpu_ids:
        max_memory[gpu_id] = f"{int(torch.cuda.get_device_properties(gpu_id).total_memory * 0.8 / 1024 ** 3)}GiB"

    # Budget for CPU RAM, used by Accelerate for whatever does not fit on the GPUs
    max_memory['cpu'] = f"{int(psutil.virtual_memory().total * 0.8 / 1024 ** 3)}GiB"

    config = AutoConfig.from_pretrained(model_path)

    is_quantized = "quantization_config" in config

    if not is_quantized:
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path,
                                                                device_map = 'auto',
                                                                max_memory=max_memory,
                                                                **kwargs)
    
    else:

        config.quantization_config['llm_int8_enable_fp32_cpu_offload'] = True
        # print(config)

        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path,
                                                                   config = config,
                                                                   device_map = 'auto',
                                                                   max_memory=max_memory,
                                                                   **kwargs)
        

    return model
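For reference, `max_memory` maps each GPU index (an `int`) to a memory budget and uses the key `"cpu"` for CPU RAM; budgets can be strings such as `"19GiB"`. Below is a sketch of the same 80% budget computation as a pure function, so the resulting dictionary can be checked without a GPU. `make_max_memory` is a hypothetical helper, not part of transformers or Accelerate.

```python
from typing import Dict, Union

GIB = 1024 ** 3


def make_max_memory(gpu_totals: Dict[int, int], cpu_total: int,
                    fraction: float = 0.8) -> Dict[Union[int, str], str]:
    """Build a max_memory dict (GPU index -> budget, plus "cpu").

    Totals are in bytes; each budget is `fraction` of the total,
    rounded down to whole GiB, as in build_model above.
    """
    max_memory: Dict[Union[int, str], str] = {
        gpu_id: f"{int(total * fraction / GIB)}GiB"
        for gpu_id, total in gpu_totals.items()
    }
    # The CPU entry is added exactly once, after all GPU entries.
    max_memory["cpu"] = f"{int(cpu_total * fraction / GIB)}GiB"
    return max_memory


print(make_max_memory({0: 24 * GIB}, 64 * GIB))
```

With real hardware, the inputs would be `{i: torch.cuda.get_device_properties(i).total_memory for i in gpu_ids}` and `psutil.virtual_memory().total`, exactly the values `build_model` reads.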

When I load the regular (non-quantized) model like this:

qwen25vl_model = build_model("Qwen/Qwen2.5-VL-7B-Instruct",
                             gpu_ids = [0],
                             
                             torch_dtype = torch.bfloat16,
                             attn_implementation = 'flash_attention_2')


# Fetching 5 files: 100%|██████████| 5/5 [02:27<00:00, 29.48s/it] 
# Loading checkpoint shards: 100%|██████████| 5/5 [00:29<00:00,  5.90s/it]
# Some parameters are on the meta device because they were offloaded to the cpu.

It loads fine and respects the memory limit I set.
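The warning "Some parameters are on the meta device because they were offloaded to the cpu" means Accelerate placed part of the model on the CPU. To see which modules went where, one can inspect the `hf_device_map` attribute that transformers sets on models loaded with a `device_map`. The helpers below (`summarize_device_map` and `offloaded_modules` are hypothetical names) just group that mapping by device; the sample mapping is made up for illustration and is not the real Qwen2.5-VL device map.

```python
from collections import defaultdict
from typing import Dict, List, Union

Device = Union[int, str]


def summarize_device_map(device_map: Dict[str, Device]) -> Dict[Device, List[str]]:
    """Group module names by the device they were assigned to."""
    by_device: Dict[Device, List[str]] = defaultdict(list)
    for module_name, device in device_map.items():
        by_device[device].append(module_name)
    return dict(by_device)


def offloaded_modules(device_map: Dict[str, Device]) -> List[str]:
    """Modules placed on "cpu" or "disk", i.e. the ones behind the meta-device warning."""
    return [name for name, device in device_map.items() if device in ("cpu", "disk")]


# With a loaded model: summarize_device_map(qwen25vl_model.hf_device_map)
example = {"visual": 0, "model.layers.0": 0, "model.layers.27": "cpu", "lm_head": "cpu"}
print(summarize_device_map(example))
print(offloaded_modules(example))
```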

However, when I try to load the quantized model `unsloth/Qwen2.5-VL-7B-Instruct-unsloth-bnb-4bit`:

qwen25vl_bnb_4bit_model = build_model("unsloth/Qwen2.5-VL-7B-Instruct-unsloth-bnb-4bit",
                                      gpu_ids = [0],
                                      torch_dtype = torch.bfloat16,
                                      attn_implementation = 'flash_attention_2')

I get a NotImplementedError as shown below:

File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/transformers/modeling_utils.py:272, in restore_default_torch_dtype.<locals>._wrapper(*args, **kwargs)
    270 old_dtype = torch.get_default_dtype()
    271 try:
--> 272     return func(*args, **kwargs)
    273 finally:
    274     torch.set_default_dtype(old_dtype)

File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/transformers/modeling_utils.py:4519, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)
   4516         device_map_kwargs["offload_buffers"] = True
   4518     if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
-> 4519         dispatch_model(model, **device_map_kwargs)
   4521 if hf_quantizer is not None:
   4522     hf_quantizer.postprocess_model(model, config=config)

File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/accelerate/big_modeling.py:423, in dispatch_model(model, device_map, main_device, state_dict, offload_dir, offload_index, offload_buffers, skip_keys, preload_module_classes, force_hooks)
    418         tied_params_map[data_ptr] = {}
    420         # Note: To handle the disk offloading case, we can not simply use weights_map[param_name].data_ptr() as the reference pointer,
    421         # as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.
--> 423 attach_align_device_hook_on_blocks(
    424     model,
    425     execution_device=execution_device,
    426     offload=offload,
    427     offload_buffers=offload_buffers,
    428     weights_map=weights_map,
    429     skip_keys=skip_keys,
    430     preload_module_classes=preload_module_classes,
    431     tied_params_map=tied_params_map,
    432 )
    434 # warn if there is any params on the meta device
    435 offloaded_devices_str = " and ".join(
    436     [device for device in set(device_map.values()) if device in ("cpu", "disk")]
    437 )

File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/accelerate/hooks.py:678, in attach_align_device_hook_on_blocks(module, execution_device, offload, weights_map, offload_buffers, module_name, skip_keys, preload_module_classes, tied_params_map)
    676 for child_name, child in module.named_children():
    677     child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
--> 678     attach_align_device_hook_on_blocks(
    679         child,
    680         execution_device=execution_device,
    681         offload=offload,
    682         weights_map=weights_map,
    683         offload_buffers=offload_buffers,
    684         module_name=child_name,
    685         preload_module_classes=preload_module_classes,
    686         skip_keys=skip_keys,
    687         tied_params_map=tied_params_map,
    688     )

File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/accelerate/hooks.py:678, in attach_align_device_hook_on_blocks(module, execution_device, offload, weights_map, offload_buffers, module_name, skip_keys, preload_module_classes, tied_params_map)
    676 for child_name, child in module.named_children():
    677     child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
--> 678     attach_align_device_hook_on_blocks(
    679         child,
    680         execution_device=execution_device,
    681         offload=offload,
    682         weights_map=weights_map,
    683         offload_buffers=offload_buffers,
    684         module_name=child_name,
    685         preload_module_classes=preload_module_classes,
    686         skip_keys=skip_keys,
    687         tied_params_map=tied_params_map,
    688     )

File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/accelerate/hooks.py:678, in attach_align_device_hook_on_blocks(module, execution_device, offload, weights_map, offload_buffers, module_name, skip_keys, preload_module_classes, tied_params_map)
    676 for child_name, child in module.named_children():
    677     child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
--> 678     attach_align_device_hook_on_blocks(
    679         child,
    680         execution_device=execution_device,
    681         offload=offload,
    682         weights_map=weights_map,
    683         offload_buffers=offload_buffers,
    684         module_name=child_name,
    685         preload_module_classes=preload_module_classes,
    686         skip_keys=skip_keys,
    687         tied_params_map=tied_params_map,
    688     )

File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/accelerate/hooks.py:660, in attach_align_device_hook_on_blocks(module, execution_device, offload, weights_map, offload_buffers, module_name, skip_keys, preload_module_classes, tied_params_map)
    653         hook = AlignDevicesHook(
    654             execution_device=execution_device[module_name],
    655             io_same_device=(module_name == ""),
    656             skip_keys=skip_keys,
    657             tied_params_map=tied_params_map,
    658         )
    659         add_hook_to_module(module, hook)
--> 660     attach_execution_device_hook(
    661         module,
    662         execution_device[module_name],
    663         preload_module_classes=preload_module_classes,
    664         skip_keys=skip_keys,
    665         tied_params_map=tied_params_map,
    666     )
    667 elif module_name == "":
    668     hook = AlignDevicesHook(
    669         execution_device=execution_device.get(""),
    670         io_same_device=True,
    671         skip_keys=skip_keys,
    672         tied_params_map=tied_params_map,
    673     )

File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/accelerate/hooks.py:453, in attach_execution_device_hook(module, execution_device, skip_keys, preload_module_classes, tied_params_map)
    450     return
    452 for child in module.children():
--> 453     attach_execution_device_hook(
    454         child,
    455         execution_device,
    456         skip_keys=skip_keys,
    457         preload_module_classes=preload_module_classes,
    458         tied_params_map=tied_params_map,
    459     )

File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/accelerate/hooks.py:442, in attach_execution_device_hook(module, execution_device, skip_keys, preload_module_classes, tied_params_map)
    414 def attach_execution_device_hook(
    415     module: torch.nn.Module,
    416     execution_device: Union[int, str, torch.device],
   (...)    419     tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
    420 ):
    421     """
    422     Recursively attaches `AlignDevicesHook` to all submodules of a given model to make sure they have the right
    423     execution device
   (...)    440             instead of duplicating memory.
    441     """
--> 442     if not hasattr(module, "_hf_hook") and len(module.state_dict()) > 0:
    443         add_hook_to_module(
    444             module,
    445             AlignDevicesHook(execution_device, skip_keys=skip_keys, tied_params_map=tied_params_map),
    446         )
    448     # Break the recursion if we get to a preload module.

File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1916, in Module.state_dict(self, destination, prefix, keep_vars, *args)
   1914 for name, module in self._modules.items():
   1915     if module is not None:
-> 1916         module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
   1917 for hook in self._state_dict_hooks.values():
   1918     hook_result = hook(self, destination, prefix, local_metadata)

File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1913, in Module.state_dict(self, destination, prefix, keep_vars, *args)
   1911 for hook in self._state_dict_pre_hooks.values():
   1912     hook(self, prefix, keep_vars)
-> 1913 self._save_to_state_dict(destination, prefix, keep_vars)
   1914 for name, module in self._modules.items():
   1915     if module is not None:

File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/bitsandbytes/nn/modules.py:464, in Linear4bit._save_to_state_dict(self, destination, prefix, keep_vars)
    461 super()._save_to_state_dict(destination, prefix, keep_vars)  # saving weight and bias
    463 if getattr(self.weight, "quant_state", None) is not None:
--> 464     for k, v in self.weight.quant_state.as_dict(packed=True).items():
    465         destination[prefix + "weight." + k] = v if keep_vars else v.detach()

File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/bitsandbytes/functional.py:810, in QuantState.as_dict(self, packed)
    795 qs_dict = {
    796     "quant_type": self.quant_type,
    797     "absmax": self.absmax,
   (...)    801     "shape": tuple(self.shape),
    802 }
    803 if self.nested:
    804     qs_dict.update(
    805         {
    806             "nested_absmax": self.state2.absmax,
    807             "nested_blocksize": self.state2.blocksize,
    808             "nested_quant_map": self.state2.code.clone(),  # un-shared to avoid restoring it after shared tensors are removed by safetensors
    809             "nested_dtype": str(self.state2.dtype).strip("torch."),
--> 810             "nested_offset": self.offset.item(),
    811         },
    812     )
    813 if not packed:
    814     return qs_dict

NotImplementedError: aten::_local_scalar_dense: attempted to run this operator with Meta tensors, but there was no abstract impl or Meta kernel registered. You may have run into this message while using an operator with PT2 compilation APIs (torch.compile/torch.export); in order to use this operator with those APIs you'll need to add an abstract impl.Please see the following doc for next steps: https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit
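Reading the traceback from the bottom up: while attaching hooks, Accelerate calls `module.state_dict()` in `attach_execution_device_hook`; that reaches bitsandbytes' `Linear4bit._save_to_state_dict`, which calls `quant_state.as_dict(packed=True)`, which in turn calls `self.offset.item()`. The error says that tensor is a Meta tensor, and a meta tensor has no data to convert into a Python number. The sketch below reproduces only that mechanism with plain PyTorch; `FakeQuantLinear` is a hypothetical stand-in, not the bitsandbytes class. Depending on the PyTorch version, the failure may surface as `NotImplementedError` or `RuntimeError`, so both are caught.

```python
import torch
import torch.nn as nn


class FakeQuantLinear(nn.Module):
    """Hypothetical stand-in for Linear4bit: state_dict() serializes a scalar
    offset with .item(), similar to QuantState.as_dict(packed=True)."""

    def __init__(self, device: str):
        super().__init__()
        self.weight = nn.Parameter(torch.zeros(4, 4, device=device), requires_grad=False)
        # Plain tensor attribute, not a registered buffer (like a quant state's tensors).
        self.offset = torch.zeros((), device=device)

    def _save_to_state_dict(self, destination, prefix, keep_vars):
        super()._save_to_state_dict(destination, prefix, keep_vars)
        # .item() needs real data; on the meta device there is none.
        destination[prefix + "weight.nested_offset"] = self.offset.item()


def state_dict_fails(device: str) -> bool:
    model = nn.Sequential(FakeQuantLinear(device))
    try:
        model.state_dict()  # the same call Accelerate makes while attaching hooks
    except (NotImplementedError, RuntimeError):
        return True
    return False


print(state_dict_fails("cpu"), state_dict_fails("meta"))
```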

I’d really appreciate any help on this 🙇‍♂️
Thank you for taking the time to read my question. I hope you have a great day!


I was able to reproduce the same error. However, it doesn’t happen with “unsloth/Qwen2.5-VL-3B-Instruct-unsloth-bnb-4bit”.
Judging from the Community section of the model page, this checkpoint seems to have known problems, although the reported symptoms differ from yours; in some cases, downgrading Transformers appears to solve them.


Thanks for your kind response! 🙇‍♂️
Unfortunately, downgrading transformers to 4.49.0 didn’t help in my case 😭
It still raises the same error:

File "~/anaconda3/envs/sample_env/lib/python3.11/site-packages/bitsandbytes/nn/modules.py", line 464, in _save_to_state_dict
    for k, v in self.weight.quant_state.as_dict(packed=True).items():
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "~/anaconda3/envs/sample_env/lib/python3.11/site-packages/bitsandbytes/functional.py", line 810, in as_dict
    "nested_offset": self.offset.item(),
                     ^^^^^^^^^^^^^^^^^^
NotImplementedError: aten::_local_scalar_dense: attempted to run this operator with Meta tensors, but there was no abstract impl or Meta kernel registered. You may have run into this message while using an operator with PT2 compilation APIs (torch.compile/torch.export); in order to use this operator with those APIs you'll need to add an abstract impl.Please see the following doc for next steps: https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit

However, as you pointed out, it works properly with unsloth/Qwen2.5-VL-3B-Instruct-unsloth-bnb-4bit!
So the problem might be related to the 7B checkpoint’s weights.

Thank you again for your kind reply, and I hope you have a great day! 🙇‍♂️


PS: In case anyone else runs into a similar problem, here is my environment. I hope it helps.

  • Python 3.11.11
  • CUDA 12.1
  • torch==2.3.1+cu121
  • torchvision==0.18.1+cu121
  • accelerate==1.5.2
  • bitsandbytes==0.45.3
  • flash-attn==2.7.3
  • transformers==4.49.0
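To compare setups, package versions like the ones listed above can be collected with the standard library alone. This is a small sketch (`collect_versions` is a hypothetical helper); it leaves out the CUDA version, since reading it (for example via `torch.version.cuda`) requires importing torch.

```python
import platform
from importlib import metadata
from typing import Dict, Iterable, Optional

PACKAGES = ["torch", "torchvision", "accelerate", "bitsandbytes", "flash-attn", "transformers"]


def collect_versions(packages: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Return {name: version}, with None for packages that are not installed."""
    versions: Dict[str, Optional[str]] = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return versions


if __name__ == "__main__":
    for name, version in collect_versions().items():
        print(f"{name}=={version}")
```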