What do I want?
- I want to load and run inference with a model on specific devices using `device_map='auto'`, to leverage the Accelerate library's efficient memory management.

What problem did I encounter?
- I hit a `NotImplementedError` when trying to load a quantized model with the `max_memory` parameter.

What have I tried?
```python
from typing import List, Union

import psutil
import torch
from transformers import AutoConfig, Qwen2_5_VLForConditionalGeneration


def build_model(model_path: str,
                gpu_ids: Union[int, List[int]] = 0,
                **kwargs):
    if isinstance(gpu_ids, int):
        gpu_ids = [gpu_ids]

    # Cap each GPU at 80% of its total memory, and the CPU at 80% of system RAM.
    max_memory = {}
    for gpu_id in gpu_ids:
        max_memory[gpu_id] = f"{int(torch.cuda.get_device_properties(gpu_id).total_memory * 0.8 / 1024 ** 3)}GiB"
    max_memory['cpu'] = f"{int(psutil.virtual_memory().total * 0.8 / 1024 ** 3)}GiB"

    config = AutoConfig.from_pretrained(model_path)
    # A checkpoint quantized on the hub carries a quantization_config in its config.json.
    is_quantized = hasattr(config, "quantization_config")

    if not is_quantized:
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path,
                                                                   device_map='auto',
                                                                   max_memory=max_memory,
                                                                   **kwargs)
    else:
        # Allow the non-quantized fp32 modules to be offloaded to the CPU.
        config.quantization_config['llm_int8_enable_fp32_cpu_offload'] = True
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_path,
                                                                   config=config,
                                                                   device_map='auto',
                                                                   max_memory=max_memory,
                                                                   **kwargs)
    return model
```
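For concreteness, on a single-GPU machine the helper ends up building a `max_memory` mapping like the one below. The numbers are hypothetical; I'm assuming a 24 GiB GPU and 64 GiB of system RAM:

```python
# Hypothetical result of the logic above for gpu_ids=[0], assuming a 24 GiB GPU
# and 64 GiB of RAM: 80% of each, truncated to an integer number of GiB.
max_memory = {0: "19GiB", "cpu": "51GiB"}
```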
When I load a regular model like this:
```python
qwen25vl_model = build_model("Qwen/Qwen2.5-VL-7B-Instruct",
                             gpu_ids=[0],
                             torch_dtype=torch.bfloat16,
                             attn_implementation='flash_attention_2')
# Fetching 5 files: 100%|██████████| 5/5 [02:27<00:00, 29.48s/it]
# Loading checkpoint shards: 100%|██████████| 5/5 [00:29<00:00, 5.90s/it]
# Some parameters are on the meta device because they were offloaded to the cpu.
```
it loads successfully and respects the memory limit I set.
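As a sanity check, I can inspect where Accelerate placed each module via the `hf_device_map` attribute that `from_pretrained` sets when a `device_map` is given (a minimal sketch; the exact module names depend on the model):

```python
# Modules mapped to "cpu" correspond to the offloaded parameters mentioned in
# the "meta device" warning above; the rest should sit on GPU 0.
for module_name, device in qwen25vl_model.hf_device_map.items():
    print(f"{module_name}: {device}")
```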
However, when I try to load the quantized model `unsloth/Qwen2.5-VL-7B-Instruct-unsloth-bnb-4bit`:

```python
qwen25vl_bnb_4bit_model = build_model("unsloth/Qwen2.5-VL-7B-Instruct-unsloth-bnb-4bit",
                                      gpu_ids=[0],
                                      torch_dtype=torch.bfloat16,
                                      attn_implementation='flash_attention_2')
```

I get a `NotImplementedError`, as shown below:
```
File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/transformers/modeling_utils.py:272, in restore_default_torch_dtype.<locals>._wrapper(*args, **kwargs)
270 old_dtype = torch.get_default_dtype()
271 try:
--> 272 return func(*args, **kwargs)
273 finally:
274 torch.set_default_dtype(old_dtype)
File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/transformers/modeling_utils.py:4519, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, weights_only, *model_args, **kwargs)
4516 device_map_kwargs["offload_buffers"] = True
4518 if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
-> 4519 dispatch_model(model, **device_map_kwargs)
4521 if hf_quantizer is not None:
4522 hf_quantizer.postprocess_model(model, config=config)
File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/accelerate/big_modeling.py:423, in dispatch_model(model, device_map, main_device, state_dict, offload_dir, offload_index, offload_buffers, skip_keys, preload_module_classes, force_hooks)
418 tied_params_map[data_ptr] = {}
420 # Note: To handle the disk offloading case, we can not simply use weights_map[param_name].data_ptr() as the reference pointer,
421 # as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.
--> 423 attach_align_device_hook_on_blocks(
424 model,
425 execution_device=execution_device,
426 offload=offload,
427 offload_buffers=offload_buffers,
428 weights_map=weights_map,
429 skip_keys=skip_keys,
430 preload_module_classes=preload_module_classes,
431 tied_params_map=tied_params_map,
432 )
434 # warn if there is any params on the meta device
435 offloaded_devices_str = " and ".join(
436 [device for device in set(device_map.values()) if device in ("cpu", "disk")]
437 )
File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/accelerate/hooks.py:678, in attach_align_device_hook_on_blocks(module, execution_device, offload, weights_map, offload_buffers, module_name, skip_keys, preload_module_classes, tied_params_map)
676 for child_name, child in module.named_children():
677 child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
--> 678 attach_align_device_hook_on_blocks(
679 child,
680 execution_device=execution_device,
681 offload=offload,
682 weights_map=weights_map,
683 offload_buffers=offload_buffers,
684 module_name=child_name,
685 preload_module_classes=preload_module_classes,
686 skip_keys=skip_keys,
687 tied_params_map=tied_params_map,
688 )
File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/accelerate/hooks.py:678, in attach_align_device_hook_on_blocks(module, execution_device, offload, weights_map, offload_buffers, module_name, skip_keys, preload_module_classes, tied_params_map)
676 for child_name, child in module.named_children():
677 child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
--> 678 attach_align_device_hook_on_blocks(
679 child,
680 execution_device=execution_device,
681 offload=offload,
682 weights_map=weights_map,
683 offload_buffers=offload_buffers,
684 module_name=child_name,
685 preload_module_classes=preload_module_classes,
686 skip_keys=skip_keys,
687 tied_params_map=tied_params_map,
688 )
File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/accelerate/hooks.py:678, in attach_align_device_hook_on_blocks(module, execution_device, offload, weights_map, offload_buffers, module_name, skip_keys, preload_module_classes, tied_params_map)
676 for child_name, child in module.named_children():
677 child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
--> 678 attach_align_device_hook_on_blocks(
679 child,
680 execution_device=execution_device,
681 offload=offload,
682 weights_map=weights_map,
683 offload_buffers=offload_buffers,
684 module_name=child_name,
685 preload_module_classes=preload_module_classes,
686 skip_keys=skip_keys,
687 tied_params_map=tied_params_map,
688 )
File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/accelerate/hooks.py:660, in attach_align_device_hook_on_blocks(module, execution_device, offload, weights_map, offload_buffers, module_name, skip_keys, preload_module_classes, tied_params_map)
653 hook = AlignDevicesHook(
654 execution_device=execution_device[module_name],
655 io_same_device=(module_name == ""),
656 skip_keys=skip_keys,
657 tied_params_map=tied_params_map,
658 )
659 add_hook_to_module(module, hook)
--> 660 attach_execution_device_hook(
661 module,
662 execution_device[module_name],
663 preload_module_classes=preload_module_classes,
664 skip_keys=skip_keys,
665 tied_params_map=tied_params_map,
666 )
667 elif module_name == "":
668 hook = AlignDevicesHook(
669 execution_device=execution_device.get(""),
670 io_same_device=True,
671 skip_keys=skip_keys,
672 tied_params_map=tied_params_map,
673 )
File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/accelerate/hooks.py:453, in attach_execution_device_hook(module, execution_device, skip_keys, preload_module_classes, tied_params_map)
450 return
452 for child in module.children():
--> 453 attach_execution_device_hook(
454 child,
455 execution_device,
456 skip_keys=skip_keys,
457 preload_module_classes=preload_module_classes,
458 tied_params_map=tied_params_map,
459 )
File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/accelerate/hooks.py:442, in attach_execution_device_hook(module, execution_device, skip_keys, preload_module_classes, tied_params_map)
414 def attach_execution_device_hook(
415 module: torch.nn.Module,
416 execution_device: Union[int, str, torch.device],
(...) 419 tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None,
420 ):
421 """
422 Recursively attaches `AlignDevicesHook` to all submodules of a given model to make sure they have the right
423 execution device
(...) 440 instead of duplicating memory.
441 """
--> 442 if not hasattr(module, "_hf_hook") and len(module.state_dict()) > 0:
443 add_hook_to_module(
444 module,
445 AlignDevicesHook(execution_device, skip_keys=skip_keys, tied_params_map=tied_params_map),
446 )
448 # Break the recursion if we get to a preload module.
File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1916, in Module.state_dict(self, destination, prefix, keep_vars, *args)
1914 for name, module in self._modules.items():
1915 if module is not None:
-> 1916 module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
1917 for hook in self._state_dict_hooks.values():
1918 hook_result = hook(self, destination, prefix, local_metadata)
File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/torch/nn/modules/module.py:1913, in Module.state_dict(self, destination, prefix, keep_vars, *args)
1911 for hook in self._state_dict_pre_hooks.values():
1912 hook(self, prefix, keep_vars)
-> 1913 self._save_to_state_dict(destination, prefix, keep_vars)
1914 for name, module in self._modules.items():
1915 if module is not None:
File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/bitsandbytes/nn/modules.py:464, in Linear4bit._save_to_state_dict(self, destination, prefix, keep_vars)
461 super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias
463 if getattr(self.weight, "quant_state", None) is not None:
--> 464 for k, v in self.weight.quant_state.as_dict(packed=True).items():
465 destination[prefix + "weight." + k] = v if keep_vars else v.detach()
File ~/anaconda3/envs/sample_env/lib/python3.11/site-packages/bitsandbytes/functional.py:810, in QuantState.as_dict(self, packed)
795 qs_dict = {
796 "quant_type": self.quant_type,
797 "absmax": self.absmax,
(...) 801 "shape": tuple(self.shape),
802 }
803 if self.nested:
804 qs_dict.update(
805 {
806 "nested_absmax": self.state2.absmax,
807 "nested_blocksize": self.state2.blocksize,
808 "nested_quant_map": self.state2.code.clone(), # un-shared to avoid restoring it after shared tensors are removed by safetensors
809 "nested_dtype": str(self.state2.dtype).strip("torch."),
--> 810 "nested_offset": self.offset.item(),
811 },
812 )
813 if not packed:
814 return qs_dict
NotImplementedError: aten::_local_scalar_dense: attempted to run this operator with Meta tensors, but there was no abstract impl or Meta kernel registered. You may have run into this message while using an operator with PT2 compilation APIs (torch.compile/torch.export); in order to use this operator with those APIs you'll need to add an abstract impl.Please see the following doc for next steps: https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit
```
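For reference, I would expect the variant below, which passes the offload flag through a `BitsAndBytesConfig` instead of mutating the hub config in place, to behave the same. This is only a sketch under my assumptions, and I haven't confirmed whether it avoids the error:

```python
from transformers import BitsAndBytesConfig

# llm_int8_enable_fp32_cpu_offload keeps the 4-bit weights on the GPU while
# allowing the remaining fp32 modules to live on the CPU.
bnb_config = BitsAndBytesConfig(load_in_4bit=True,
                                llm_int8_enable_fp32_cpu_offload=True)

qwen25vl_bnb_4bit_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "unsloth/Qwen2.5-VL-7B-Instruct-unsloth-bnb-4bit",
    quantization_config=bnb_config,
    device_map='auto',
    max_memory=max_memory,
    torch_dtype=torch.bfloat16,
)
```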
I’d really appreciate any help on this 🙇‍♂️
Thank you for taking the time to read my question. I hope you have a great day!