DeepSpeed inference ZeRO stage 3 + quantization

I'm trying to set up DeepSpeed inference with ZeRO stage 3 and on-the-fly weight quantization as follows:

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch 
from transformers.integrations import HfDeepSpeedConfig


with torch.no_grad():
    hfds_config = HfDeepSpeedConfig(config_file_or_dict="config.json")
    # Now model is on-the-fly quantized.
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
    model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it")

input_text = "Query"
input_ids = tokenizer(input_text, return_tensors="pt")

outputs = model.generate(**input_ids, max_length=200)
print(tokenizer.decode(outputs[0]))

The DeepSpeed config file config.json is as follows:

{
    "weight_quantization": {
        "quantized_initialization": {
            "num_bits": 4,
            "group_size": 64,
            "group_dim": 1,
            "symmetric": false
        }
    }, 
    "zero_optimization": {
          "stage": 3,
          "offload_optimizer": {
              "device": "cpu",
              "pin_memory": true
          },
          "offload_param": {
              "device": "cpu",
              "pin_memory": true
          },
          "overlap_comm": true,
          "contiguous_gradients": true,
          "sub_group_size": 1e9,
          "reduce_bucket_size": "auto",
          "stage3_prefetch_bucket_size": "auto",
          "stage3_param_persistence_threshold": "auto",
          "stage3_max_live_parameters": 1e9,
          "stage3_max_reuse_distance": 1e9,
          "stage3_gather_fp16_weights_on_model_save": true
      }, 
      "train_batch_size": 32, 
      "torch_dtype": "float32"
  }
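
If it matters, I believe the same configuration can also be passed to HfDeepSpeedConfig as a plain Python dict instead of a file path; a trimmed-down sketch of what I understand to be the equivalent (keeping only the quantization and ZeRO settings):

ds_config = {
    "weight_quantization": {
        "quantized_initialization": {
            "num_bits": 4,
            "group_size": 64,
            "group_dim": 1,
            "symmetric": False,
        }
    },
    "zero_optimization": {"stage": 3},
    "train_batch_size": 32,
}
hfds_config = HfDeepSpeedConfig(ds_config)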

and the error is as follows:

[2024-03-08 15:12:19,262] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-03-08 15:12:19,600] [WARNING] [config_utils.py:69:_process_deprecated_field] Config parameter stage3_gather_fp16_weights_on_model_save is deprecated use gather_16bit_weights_on_model_save instead
[2024-03-08 15:12:19,601] [INFO] [comm.py:637:init_distributed] cdb=None
[2024-03-08 15:12:19,602] [INFO] [comm.py:652:init_distributed] Not using the DeepSpeed or dist launchers, attempting to detect MPI environment…
[2024-03-08 15:12:19,624] [INFO] [comm.py:702:mpi_discovery] Discovered MPI settings of world_rank=0, local_rank=0, world_size=1, master_addr=10.0.0.4, master_port=29500
[2024-03-08 15:12:19,626] [INFO] [comm.py:668:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
[2024-03-08 15:12:19,635] [INFO] [partition_parameters.py:559:patch_init_and_builtins] Enable Zero3 engine with INT4 quantization.
[2024-03-08 15:12:22,156] [INFO] [partition_parameters.py:343:exit] finished initializing model - num_params = 165, num_elems = 3.03B
{
"name": "RuntimeError",
"message": "self.size(-1) must be divisible by 2 to view BFloat16 as Float (different element sizes), but got 1",
"stack": "---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[1], line 10
8 # Now model is on-the-fly quantized.
9 tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it")
→ 10 model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it")
12 input_text = "Query"
13 input_ids = tokenizer(input_text, return_tensors="pt")

File ~/.local/lib/python3.8/site-packages/transformers/models/auto/auto_factory.py:561, in _BaseAutoModelClass.from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
559 elif type(config) in cls._model_mapping.keys():
560 model_class = _get_model_class(config, cls._model_mapping)
→ 561 return model_class.from_pretrained(
562 pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
563 )
564 raise ValueError(
565 f"Unrecognized configuration class {config.__class__} for this kind of AutoModel: {cls.__name__}.\n"
566 f"Model type should be one of {', '.join(c.__name__ for c in cls._model_mapping.keys())}."
567 )

File ~/.local/lib/python3.8/site-packages/transformers/modeling_utils.py:3502, in PreTrainedModel.from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)
3493 if dtype_orig is not None:
3494 torch.set_default_dtype(dtype_orig)
3495 (
3496 model,
3497 missing_keys,
3498 unexpected_keys,
3499 mismatched_keys,
3500 offload_index,
3501 error_msgs,
→ 3502 ) = cls._load_pretrained_model(
3503 model,
3504 state_dict,
3505 loaded_state_dict_keys, # XXX: rename?
3506 resolved_archive_file,
3507 pretrained_model_name_or_path,
3508 ignore_mismatched_sizes=ignore_mismatched_sizes,
3509 sharded_metadata=sharded_metadata,
3510 _fast_init=_fast_init,
3511 low_cpu_mem_usage=low_cpu_mem_usage,
3512 device_map=device_map,
3513 offload_folder=offload_folder,
3514 offload_state_dict=offload_state_dict,
3515 dtype=torch_dtype,
3516 hf_quantizer=hf_quantizer,
3517 keep_in_fp32_modules=keep_in_fp32_modules,
3518 )
3520 # make sure token embedding weights are still tied if needed
3521 model.tie_weights()

File ~/.local/lib/python3.8/site-packages/transformers/modeling_utils.py:3945, in PreTrainedModel._load_pretrained_model(cls, model, state_dict, loaded_keys, resolved_archive_file, pretrained_model_name_or_path, ignore_mismatched_sizes, sharded_metadata, _fast_init, low_cpu_mem_usage, device_map, offload_folder, offload_state_dict, dtype, hf_quantizer, keep_in_fp32_modules)
3943 error_msgs += new_error_msgs
3944 else:
→ 3945 error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
3947 # force memory release
3948 del state_dict

File ~/.local/lib/python3.8/site-packages/transformers/modeling_utils.py:626, in _load_state_dict_into_model(model_to_load, state_dict, start_prefix)
623 if child is not None:
624 load(child, state_dict, prefix + name + ".")
→ 626 load(model_to_load, state_dict, prefix=start_prefix)
627 # Delete state_dict so it could be collected by GC earlier. Note that state_dict is a copy of the argument, so
628 # it’s safe to delete it.
629 del state_dict

File ~/.local/lib/python3.8/site-packages/transformers/modeling_utils.py:624, in _load_state_dict_into_model.<locals>.load(module, state_dict, prefix)
622 for name, child in module._modules.items():
623 if child is not None:
→ 624 load(child, state_dict, prefix + name + ".")

File ~/.local/lib/python3.8/site-packages/transformers/modeling_utils.py:624, in _load_state_dict_into_model.<locals>.load(module, state_dict, prefix)
622 for name, child in module._modules.items():
623 if child is not None:
→ 624 load(child, state_dict, prefix + name + ".")

File ~/.local/lib/python3.8/site-packages/transformers/modeling_utils.py:618, in _load_state_dict_into_model.<locals>.load(module, state_dict, prefix)
616 with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
617 if torch.distributed.get_rank() == 0:
→ 618 module._load_from_state_dict(*args)
619 else:
620 module._load_from_state_dict(*args)

File ~/.local/lib/python3.8/site-packages/deepspeed/inference/quantization/utils.py:269, in wrap_load_from_state_dict.<locals>.wrapper(model, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
267 quantized_weight, quant_scale, quant_min = model.weight.quantizer.quantize(state_dict[key])
268 quantized_weight = quantized_weight.view(model.weight.dtype)
→ 269 quant_scale = quant_scale.view(model.weight.dtype)
270 quant_min = quant_min.view(model.weight.dtype)
272 replaced_old_value = state_dict[key]

RuntimeError: self.size(-1) must be divisible by 2 to view BFloat16 as Float (different element sizes), but got 1"
}
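
If I read the failing line correctly, quant_scale ends up as a BFloat16 tensor whose last dimension is 1, and it is being reinterpreted as Float (4 bytes per element instead of 2), which torch only allows when the last dimension is even. A standalone snippet that, as far as I can tell, triggers the exact same error:

import torch

# bfloat16 is 2 bytes per element, float32 is 4, so viewing bf16 storage as
# float32 needs an even last dimension; here it is 1, like quant_scale above
scale = torch.randn(8, 1, dtype=torch.bfloat16)
scale.view(torch.float32)  # RuntimeError: self.size(-1) must be divisible by 2 ...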

I tried setting group_dim to 2 in config.json, but that gave a tuple index out of range error.
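
For reference, this is roughly the weight_quantization block I used for that attempt (everything else in config.json unchanged):

"weight_quantization": {
    "quantized_initialization": {
        "num_bits": 4,
        "group_size": 64,
        "group_dim": 2,
        "symmetric": false
    }
}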

My GPU doesn't support the bfloat16 format. I thought bfloat16 was therefore disabled by default, but I also explicitly set the dtype to fp32 in the config (not that this changed anything). How can I fix this issue? The reason I want to use ZeRO stage 3 is to go from a 2B model to a 7B model with CPU offloading (and also to try this out a bit).
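
For completeness, the only other way I can think of to force fp32 is to pass the dtype when loading the model, roughly like the sketch below (I don't know whether this interacts correctly with the quantized initialization):

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2b-it",
    torch_dtype=torch.float32,  # load the weights as fp32 instead of the checkpoint's native dtype
)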