DeepSpeed ZeRO-3 flattens convolution weights, causing a runtime error

Hi,

I’m fine-tuning a multimodal LLM on video datasets, and because of the scale of both the model and the data I decided to use DeepSpeed with ZeRO-3 for GPU memory efficiency. (I previously tried FSDP, but it didn’t solve the CUDA OOM issue.)

To start with, the model I’m trying to fine-tune is as follows:

# Imports needed by this snippet; CambrianConfig, CambrianMetaModel and
# CambrianMetaForCausalLM come from the project's own modules.
import torch.nn as nn
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel


class CambrianLlamaModel(CambrianMetaModel, LlamaModel):
    config_class = CambrianConfig

    def __init__(self, config: LlamaConfig):
        super(CambrianLlamaModel, self).__init__(config)
    
    # more code...

class CambrianLlamaForCausalLM(LlamaForCausalLM, CambrianMetaForCausalLM):
    config_class = CambrianConfig

    def __init__(self, config):
        super(LlamaForCausalLM, self).__init__(config)

        self.model = CambrianLlamaModel(config)
        self.pretraining_tp = config.pretraining_tp
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    # more code ...

model = CambrianLlamaForCausalLM.from_pretrained(
    model_args.input_model_filename,
    **bnb_model_from_pretrained_args,
)

In the fine-tuning script, I create an LLaVATrainer instance:

trainer = LLaVATrainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        callbacks=callbacks,
        deepspeed=training_args.deepspeed,
        **data_module,
    )
trainer.train()

And my DeepSpeed config is as follows:

{
  "compute_environment": "LOCAL_MACHINE",
  "debug": false,
  "deepspeed_config": {
    "gradient_accumulation_steps": 1,
    "gradient_clipping": 8.0,
    "offload_optimizer_device": "cpu",
    "offload_param_device": "cpu",
    "zero3_init_flag": true,
    "zero3_save_16bit_model": true,
    "zero_stage": 3
  },
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": {
      "device": "cpu"
    },
    "offload_param": {
      "device": "cpu"
    },
    "overlap_comm": true,
    "contiguous_gradients": true,
    "sub_group_size": 1e9,
    "reduce_bucket_size": "auto",
    "stage3_prefetch_bucket_size": 8493465,
    "stage3_param_persistence_threshold": "auto",
    "stage3_max_live_parameters": 1e9,
    "stage3_max_reuse_distance": 1e9,
    "stage3_gather_16bit_weights_on_model_save": true
  },
  "train_micro_batch_size_per_gpu": 1,
  "distributed_type": "DEEPSPEED",
  "downcast_bf16": "no",
  "dynamo_config": {
    "dynamo_backend": "INDUCTOR"
  },
  "enable_cpu_affinity": false,
  "machine_rank": 0,
  "main_training_function": "main",
  "mixed_precision": "bf16",
  "num_machines": 1,
  "num_processes": 3,
  "rdzv_backend": "static",
  "same_network": true,
  "tpu_env": [],
  "tpu_use_cluster": false,
  "tpu_use_sudo": false,
  "use_cpu": false
}

During fine-tuning, I hit a runtime error about the dimensionality of a conv layer’s weight:

Traceback (most recent call last):
...
  File "user/LongVidLLaMA/./longvu/language_model/cambrian_llama.py", line 304, in forward
    ) = self.prepare_inputs_labels_for_multimodal(
  File "user/LongVidLLaMA/./longvu/cambrian_arch.py", line 830, in prepare_inputs_labels_for_multimodal
    image_aux_features_dino = self.encode_images(
  File "user/LongVidLLaMA/./longvu/cambrian_arch.py", line 609, in encode_images
    image_aux_features_chunk = vision_tower_aux(chunk)
  File "user/Python-3.10.12/thesis_longvu/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "user/Python-3.10.12/thesis_longvu/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "user/Python-3.10.12/thesis_longvu/lib/python3.10/site-packages/transformers/models/dinov2/modeling_dinov2.py", line 162, in forward
    embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2)
  File "user/Python-3.10.12/thesis_longvu/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "user/Python-3.10.12/thesis_longvu/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "user/Python-3.10.12/thesis_longvu/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "user/Python-3.10.12/thesis_longvu/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: weight should have at least three dimensions
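
The failing call boils down to something like the snippet below (illustrative values, not from my code): F.conv2d rejects any weight with fewer than three dimensions, which is exactly what a flattened/partitioned parameter looks like.

import torch
import torch.nn.functional as F

pixel_values = torch.randn(1, 3, 224, 224)  # dummy image batch
flat_weight = torch.empty(0)                # what a ZeRO-3-partitioned conv weight presents as
F.conv2d(pixel_values, flat_weight)         # RuntimeError: weight should have at least three dimensions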

From some initial searching, this error seems to stem from ZeRO stage 3, where the model’s parameters (not just gradients and optimizer states) are sharded across GPUs. To save memory, DeepSpeed partitions each parameter into a flattened 1D buffer and only restores the original shape when the parameter is gathered for computation. For certain modules such as convolution layers, that original shape apparently isn’t being restored in my setup: the patch embedding in my DINOv2 module is a Conv2d whose weight must have 3 or 4 dimensions, but it arrives as a flattened tensor with fewer than 3 dimensions. This mismatch is what triggers the error from torch.nn.functional.conv2d, which requires a weight tensor with at least 3 dimensions.
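
For debugging, a sketch like the one below can show whether the weight really is a ZeRO-3 placeholder. deepspeed.zero.GatheredParameters and the ds_shape attribute are standard ZeRO-3 APIs, but the module path to the DINOv2 patch-embedding Conv2d is only my guess and will need adjusting for the actual model.

import deepspeed

# Hypothetical path to the DINOv2 patch-embedding Conv2d -- adjust to the real model structure.
conv = model.get_model().vision_tower_aux_list[0].vision_tower.embeddings.patch_embeddings.projection

print(conv.weight.shape)                       # a partitioned param shows up flattened/empty, e.g. torch.Size([0])
print(getattr(conv.weight, "ds_shape", None))  # ZeRO-3 records the original 4-D shape here

# Temporarily gather the full parameter to confirm the underlying weight is intact.
with deepspeed.zero.GatheredParameters([conv.weight], modifier_rank=None):
    print(conv.weight.shape)                   # should now be the full [out_ch, in_ch, kH, kW]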

I tried wrapping the vision-tower loading code in deepspeed.zero.Init(enabled=False):

class DinoVisionTower(BaseVisionTower):
...
    def load_model(self, device_map=None):

        # self.vision_tower = Dinov2Model.from_pretrained(self.vision_tower_name)
        with deepspeed.zero.Init(enabled=False):
            self.vision_tower = Dinov2Model.from_pretrained(self.vision_tower_name)
...

However, the error above still persists.
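
It may also be worth checking (sketch below, based on my understanding of the transformers/DeepSpeed integration) whether ZeRO-3 init is applied globally anyway: when a ZeRO-3 config is active, from_pretrained is wrapped in zero.Init by the integration itself, so a locally disabled context may not be the only place where partitioning happens.

# Hedged check; the import path may differ across transformers versions
# (older versions expose it as transformers.deepspeed.is_deepspeed_zero3_enabled).
from transformers.integrations import is_deepspeed_zero3_enabled

print(is_deepspeed_zero3_enabled())  # True means from_pretrained is wrapped in zero.Init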

Hence, I’d love to hear from anyone who has encountered this scenario before when using DeepSpeed ZeRO-3.

Thanks in advance.
