Failed to load model when using HF-integrated DeepSpeed, but no error when model loading and DeepSpeed initialization are done separately

I am trying to train the Muffin project (GitHub) with DeepSpeed.
However, I run into difficulties when loading the model.
When running a command like

export CUDA_VISIBLE_DEVICES=0
deepspeed ./muffin/train/debug.py \
--model_name_or_path ./RLHF-V_SFT_weight \
--deepspeed ./script/train/zero3.json \
--output_dir debug

just to load the model, an error is raised. The Python script is:

import torch
from muffin import Beit3LlavaLlamaForCausalLM
import transformers
from typing import Optional
from dataclasses import dataclass, field

@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="facebook/opt-125m")

def load():
    # Parsing TrainingArguments consumes --deepspeed and registers the ZeRO-3
    # config globally, so from_pretrained below builds the model under
    # deepspeed.zero.Init.
    parser = transformers.HfArgumentParser(
        (ModelArguments, transformers.TrainingArguments))
    model_args, training_args = parser.parse_args_into_dataclasses()

    model = Beit3LlavaLlamaForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch.float16
    )

if __name__ == "__main__":
    load()

Note that BEiT-3 is created via timm.create_model, as sketched below.
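For reference, based on the traceback further down, the vision tower is built roughly like this inside muffin/model/muffin.py (a paraphrase, not the verbatim source; muffin/model/beit3.py registers the factory with timm's @register_model):

import timm
import muffin.model.beit3  # registers "beit3_large_patch16_672" with timm

# mm_vision_tower comes from the model config; the traceback shows it
# resolving to the BEiT-3 factory registered above.
vision_tower = timm.create_model("beit3_large_patch16_672")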

Deepspeed configuration:

{
    "fp16": {
        "enabled": true,
        "loss_scale": 0,
        "loss_scale_window": 1000,
        "initial_scale_power": 16,
        "hysteresis": 2,
        "min_loss_scale": 1
    },
    "bf16": {
        "enabled": false
    },
    "train_micro_batch_size_per_gpu": 8,
    "train_batch_size": 8,
    "gradient_accumulation_steps": 1,
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": true,
        "contiguous_gradients": true,
        "sub_group_size": 1e9,
        "reduce_bucket_size": "auto",
        "stage3_prefetch_bucket_size": "auto",
        "stage3_param_persistence_threshold": "auto",
        "stage3_max_live_parameters": 1e9,
        "stage3_max_reuse_distance": 1e9,
        "stage3_gather_16bit_weights_on_model_save": true
    }
}
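If I understand the transformers integration correctly, it is this ZeRO stage-3 config that changes how from_pretrained behaves. The failure can apparently be reproduced without HfArgumentParser or the Trainer by registering the config through HfDeepSpeedConfig before loading (a sketch; in transformers 4.33 the class lives in transformers.deepspeed, newer versions moved it to transformers.integrations):

import json
import torch
from transformers.deepspeed import HfDeepSpeedConfig
from muffin import Beit3LlavaLlamaForCausalLM

ds_config = json.load(open('./script/train/zero3.json'))
# Keeping this object alive makes is_deepspeed_zero3_enabled() return True,
# which is the switch from_pretrained checks before wrapping model
# construction in deepspeed.zero.Init.
dschf = HfDeepSpeedConfig(ds_config)

# This should now fail with the same fan_in/fan_out ValueError as the
# HfArgumentParser version.
model = Beit3LlavaLlamaForCausalLM.from_pretrained(
    './RLHF-V_SFT_weight', torch_dtype=torch.float16)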

However, if I do not use the HuggingFace-integrated DeepSpeed (no HfArgumentParser passing the --deepspeed parameter), i.e. if I run the code below, there are no errors.

import deepspeed, json, torch
from muffin import Beit3LlavaLlamaForCausalLM

# Load the model first, outside any DeepSpeed context, so every parameter
# is materialized normally.
model = Beit3LlavaLlamaForCausalLM.from_pretrained(
    model_name_or_path,  # same checkpoint path as above
    torch_dtype=torch.float16
)

# Only then hand the fully built model to DeepSpeed.
model_engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config_params=json.load(open('./script/train/zero3.json')))
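To narrow down where the two paths diverge, here is a minimal sketch (my own test, not code from the Muffin repo) of what ZeRO-3's zero.Init does to every freshly constructed nn.Linear; run it under the deepspeed launcher so the process group environment variables exist:

import deepspeed
import torch.nn as nn

deepspeed.init_distributed()  # uses the env vars set by the deepspeed launcher

# Minimal ZeRO-3 config: zero.Init partitions parameters the moment each
# submodule finishes __init__.
ds_cfg = {"train_micro_batch_size_per_gpu": 1, "zero_optimization": {"stage": 3}}

with deepspeed.zero.Init(config_dict_or_path=ds_cfg):
    layer = nn.Linear(1024, 1024)

print(layer.weight.shape)     # torch.Size([0]) -- zero-element placeholder
print(layer.weight.ds_shape)  # torch.Size([1024, 1024]) -- real shape kept by DeepSpeed

This matches the "Initializing zero-element tensors is a no-op" warnings in the log below: inside zero.Init, every weight that torchscale later touches is such a placeholder.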

Error message:

[2024-04-25 04:12:22,346] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-04-25 04:12:24,796] [INFO] [launch.py:145:main] WORLD INFO DICT: {'localhost': [0]}
[2024-04-25 04:12:24,796] [INFO] [launch.py:151:main] nnodes=1, num_local_procs=1, node_rank=0
[2024-04-25 04:12:24,796] [INFO] [launch.py:162:main] global_rank_mapping=defaultdict(<class 'list'>, {'localhost': [0]})
[2024-04-25 04:12:24,796] [INFO] [launch.py:163:main] dist_world_size=1
[2024-04-25 04:12:24,797] [INFO] [launch.py:165:main] Setting CUDA_VISIBLE_DEVICES=0
[2024-04-25 04:12:27,946] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)
[2024-04-25 04:12:28,765] [WARNING] [comm.py:152:init_deepspeed_backend] NCCL backend in DeepSpeed not yet implemented
[2024-04-25 04:12:28,765] [INFO] [comm.py:594:init_distributed] cdb=None
[2024-04-25 04:12:28,765] [INFO] [comm.py:625:init_distributed] Initializing TorchBackend in DeepSpeed with backend nccl
/root/anaconda3/envs/muffin/lib/python3.10/site-packages/torch/nn/init.py:405: UserWarning: Initializing zero-element tensors is a no-op
  warnings.warn("Initializing zero-element tensors is a no-op")
[2024-04-25 04:12:30,365] [INFO] [partition_parameters.py:453:__exit__] finished initializing model with 12.92B parameters
Traceback (most recent call last):
  File "/root/data/yflu/muffin/./muffin/train/debug.py", line 22, in <module>
    load()
  File "/root/data/yflu/muffin/./muffin/train/debug.py", line 16, in load
    model = Beit3LlavaLlamaForCausalLM.from_pretrained(
  File "/root/anaconda3/envs/muffin/lib/python3.10/site-packages/transformers/modeling_utils.py", line 2959, in from_pretrained
    model = cls(config, *model_args, **model_kwargs)
  File "/root/anaconda3/envs/muffin/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 385, in wrapper
    f(module, *args, **kwargs)
  File "/root/data/yflu/muffin/muffin/model/muffin.py", line 311, in __init__
    self.model = Beit3LlavaLlamaModel(config, mm_vision_tower=mm_vision_tower)
  File "/root/anaconda3/envs/muffin/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 385, in wrapper
    f(module, *args, **kwargs)
  File "/root/data/yflu/muffin/muffin/model/muffin.py", line 153, in __init__
    self.vision_tower = timm.create_model(mm_vision_tower)
  File "/root/anaconda3/envs/muffin/lib/python3.10/site-packages/timm/models/factory.py", line 81, in create_model
    model = create_fn(pretrained=pretrained, **kwargs)
  File "/root/data/yflu/muffin/muffin/model/beit3.py", line 135, in beit3_large_patch16_672
    model = BEiT3Wrapper(args, **kwargs)
  File "/root/anaconda3/envs/muffin/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 385, in wrapper
    f(module, *args, **kwargs)
  File "/root/data/yflu/muffin/muffin/model/beit3.py", line 51, in __init__
    self.beit3 = BEiT3(args)
  File "/root/anaconda3/envs/muffin/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 385, in wrapper
    f(module, *args, **kwargs)
  File "/root/anaconda3/envs/muffin/lib/python3.10/site-packages/torchscale/model/BEiT3.py", line 40, in __init__
    self.encoder = Encoder(
  File "/root/anaconda3/envs/muffin/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 385, in wrapper
    f(module, *args, **kwargs)
  File "/root/anaconda3/envs/muffin/lib/python3.10/site-packages/torchscale/architecture/encoder.py", line 209, in __init__
    self.build_encoder_layer(
  File "/root/anaconda3/envs/muffin/lib/python3.10/site-packages/torchscale/architecture/encoder.py", line 296, in build_encoder_layer
    layer = EncoderLayer(
  File "/root/anaconda3/envs/muffin/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 385, in wrapper
    f(module, *args, **kwargs)
  File "/root/anaconda3/envs/muffin/lib/python3.10/site-packages/torchscale/architecture/encoder.py", line 30, in __init__
    self.self_attn = self.build_self_attention(self.embed_dim, args)
  File "/root/anaconda3/envs/muffin/lib/python3.10/site-packages/torchscale/architecture/encoder.py", line 103, in build_self_attention
    return MultiheadAttention(
  File "/root/anaconda3/envs/muffin/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 385, in wrapper
    f(module, *args, **kwargs)
  File "/root/anaconda3/envs/muffin/lib/python3.10/site-packages/torchscale/component/multihead_attention.py", line 40, in __init__
    self.k_proj = MultiwayWrapper(args, nn.Linear(embed_dim, embed_dim, bias=True))
  File "/root/anaconda3/envs/muffin/lib/python3.10/site-packages/torchscale/component/multiway_network.py", line 12, in MultiwayWrapper
    return MultiwayNetwork(module, dim=dim)
  File "/root/anaconda3/envs/muffin/lib/python3.10/site-packages/deepspeed/runtime/zero/partition_parameters.py", line 385, in wrapper
    f(module, *args, **kwargs)
  File "/root/anaconda3/envs/muffin/lib/python3.10/site-packages/torchscale/component/multiway_network.py", line 30, in __init__
    self.B.reset_parameters()
  File "/root/anaconda3/envs/muffin/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 109, in reset_parameters
    fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
  File "/root/anaconda3/envs/muffin/lib/python3.10/site-packages/torch/nn/init.py", line 287, in _calculate_fan_in_and_fan_out
    raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
ValueError: Fan in and fan out can not be computed for tensor with fewer than 2 dimensions

It seems that the multi-head attention module of BEiT-3 cannot be initialized: the weight of the linear layer self.B in MultiwayNetwork is a zero-element tensor when reset_parameters() is called. My understanding is that, with the HF integration and ZeRO-3, from_pretrained constructs the model inside deepspeed.zero.Init, which partitions each parameter as soon as its submodule is created, so by the time MultiwayNetwork copies the linear layer its weight is already an empty placeholder.
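If I read the traceback correctly, torchscale's MultiwayNetwork does roughly the following (paraphrased from torchscale/component/multiway_network.py; comments are mine):

import copy
import torch.nn as nn

class MultiwayNetwork(nn.Module):
    def __init__(self, module, dim=1):
        super().__init__()
        self.dim = dim
        self.A = module
        # Under zero.Init, module's weight has already been partitioned into
        # a zero-element placeholder, so the deep copy is empty as well...
        self.B = copy.deepcopy(module)
        # ...and reset_parameters() then fails in
        # _calculate_fan_in_and_fan_out on a tensor with < 2 dimensions.
        self.B.reset_parameters()

A workaround I am considering (an untested sketch): gather the full weights around the copy so reset_parameters() sees real shapes, e.g.

with deepspeed.zero.GatheredParameters(list(module.parameters())):
    self.B = copy.deepcopy(module)
    self.B.reset_parameters()

though I have not verified how the copied, now unpartitioned parameters interact with ZeRO-3 afterwards.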

My environment:
V100 with CUDA 11.7
deepspeed 0.9.5
torch 2.0.1
transformers 4.33.3