Tensor shape mismatch error when doing an allgather in distributed training with FSDP

Hi,

I'm fine-tuning a multimodal LLM and, during training, I run into the following error when saving a checkpoint. More specifically, the model weights are saved normally, but saving the optimizer states fails with:

RuntimeError: Detected mismatch between collectives on ranks. Rank 0 is running collective: CollectiveFingerPrint(SequenceNumber=549, OpType=ALLGATHER, TensorShape=[0], TensorDtypes=Float, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))), but Rank 2 is running collective: CollectiveFingerPrint(SequenceNumber=549, OpType=ALLGATHER, TensorShape=[183971584], TensorDtypes=Float, TensorDeviceTypes=TensorOptions(dtype=float (default), device=cuda, layout=Strided (default), requires_grad=false (default), pinned_memory=false (default), memory_format=(nullopt))).Collectives differ in the following aspects:   Tensor Tensor shapes: 0vs 183971584

This is the full traceback of the tensor shape mismatch when saving the FSDP optimizer state:

File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/transformers/trainer.py", line 2356, in _inner_training_loop
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/transformers/trainer.py", line 2807, in _maybe_log_save_evaluate
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/transformers/trainer.py", line 2890, in _save_checkpoint
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/transformers/trainer.py", line 3001, in _save_optimizer_and_scheduler
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/accelerate/utils/fsdp_utils.py", line 185, in save_fsdp_optimizer
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1828, in optim_state_dict
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 1253, in _optim_state_dict_impl
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1396, in _optim_state_dict
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1657, in _gather_orig_param_state
File "/media02/nthuy/miniconda3/envs/thesis_longvu/lib/python3.10/site-packages/torch/distributed/fsdp/_optim_utils.py", line 1593, in _all_gather_optim_state

The package versions are:

torch==2.1.2
numpy==1.26.4
transformers==4.43.1

Below are the full details of how I traced the error through the functions listed in the traceback. It's quite long, but it leads to what I believe is a possible cause of the error above, and to my question at the end:

Summary

I began digging into the codebase, starting with torch.distributed.fsdp, to track down the cause:

work = dist.all_gather(
    tensors, local_state, group=fsdp_state.process_group, async_op=True
)

tensors is a list in which some elements have zero elements and others are non-empty, matching the mismatch reported in the error above. When I printed object_state.tensors on each rank:

2025-02-22 17:02:42,065 - root - DEBUG - rank 0, object_state.tensors: {}, name: exp_avg, info: None
2025-02-22 17:02:42,065 - root - DEBUG - rank 1, object_state.tensors: {}, name: exp_avg, info: None
2025-02-22 17:02:42,066 - root - DEBUG - rank 2, object_state.tensors: {'exp_avg': _PosDimTensorInfo(shape=torch.Size([183971584]), dtype=torch.float32), 'exp_avg_sq': _PosDimTensorInfo(shape=torch.Size([183971584]), dtype=torch.float32)}, name: exp_avg, info: _PosDimTensorInfo(shape=torch.Size([183971584]), dtype=torch.float32)                                                                                            
2025-02-22 17:02:42,066 - root - DEBUG - rank 3, object_state.tensors: {'exp_avg': _PosDimTensorInfo(shape=torch.Size([210030848]), dtype=torch.float32), 'exp_avg_sq': _PosDimTensorInfo(shape=torch.Size([210030848]), dtype=torch.float32)}, name: exp_avg, info: _PosDimTensorInfo(shape=torch.Size([210030848]), dtype=torch.float32)

It can be seen that the tensors are empty on ranks 0 and 1. These object_state.tensors are gathered into object_list from the processes in the process group via:

dist.all_gather_object(object_list, processed_state, group=fsdp_state.process_group)

It seems that processed_state on ranks 0 and 1 is empty (StateInfo({}, {}, {})). This, in turn, is caused by an empty optim_state, as can be seen from the for loop at the beginning of _all_gather_optim_state().

2025-02-22 17:02:42,060 - root - DEBUG - @tcm: In _all_gather_optim_state(): optim_state: {}
2025-02-22 17:02:42,060 - root - DEBUG - @tcm: In _gather_orig_param_state(): optim_state: {}
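For reference, this kind of mismatch can be reproduced in isolation with a tiny script. This is a hypothetical sketch of my own, not the actual FSDP code path; launch it with torchrun --nproc_per_node=2, and, to my understanding, TORCH_DISTRIBUTED_DEBUG=DETAIL is what turns the mismatch into the CollectiveFingerPrint error rather than a hang or a plain size error:

# repro_allgather_mismatch.py -- hypothetical minimal sketch, not the FSDP code path.
# Launch: TORCH_DISTRIBUTED_DEBUG=DETAIL torchrun --nproc_per_node=2 repro_allgather_mismatch.py
import torch
import torch.distributed as dist

def main():
    dist.init_process_group("gloo")  # gloo so the sketch also runs on CPU-only machines
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Mimic the FSDP situation: rank 0 has an empty local optimizer state
    # (numel 0) while the other rank has a non-empty one.
    numel = 0 if rank == 0 else 8
    local_state = torch.ones(numel)

    # Each rank sizes its gather buffers from its *own* local state, so the
    # tensor shapes of the collective disagree across ranks and the all_gather
    # fails (the script is expected to error out here).
    tensors = [torch.empty(numel) for _ in range(world_size)]
    dist.all_gather(tensors, local_state)

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

Going back to the FSDP code, the excerpt below is the loop in _optim_state_dict() (in _optim_utils.py) that eventually calls _gather_orig_param_state():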
    # Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers
    # across ranks
    for optim_state_key in all_optim_state_keys:
        param_key: Union[str, int, None] = optim_state_key_to_param_key.get(
            optim_state_key, None
        )

        if param_key is None:
            assert use_orig_params, (
                "If use_orig_params is False, we must be able to find the "
                f"corresponding param id. {optim_state_key} {param_key}"
            )
            if not optim_state_key.is_fsdp_managed:
                continue

        if optim_state_key.is_fsdp_managed:
            # If there are multiple unflat_param_names (not use_orig_params),
            # they share the same FSDPParamInfo. So the first unflat_param_name
            # is sufficient to fetch the FSDPParamInfo.
            fqn = optim_state_key.unflat_param_names[0]
            fsdp_param_info = fqn_to_fsdp_param_info[fqn]
            if use_orig_params:
                state = (
                    {} if param_key is None else optim_state_dict["state"][param_key]
                )
                unflat_state = [
                    _gather_orig_param_state(
                        fsdp_param_info,
                        fqn,
                        state,
                        shard_state,
                    )
                ]

The problem is that param_key is None, which leads to an empty state being passed to _gather_orig_param_state():

state = ({} if param_key is None else optim_state_dict["state"][param_key])

param_key is None because the optim_state_key_to_param_key dictionary is empty:

2025-02-22 17:02:42,046 - root - DEBUG - @tcm: In _optim_state_dict(): optim_state_key_to_param_key: {} # rank 0 or 1
2025-02-22 17:02:42,046 - root - DEBUG - @tcm: In _optim_state_dict(): optim_state_key_to_param_key: {_OptimStateKey(unflat_param_names=('lm_head.weight',), is_fsdp_managed=True): 2} # rank 2 or 3
2025-02-22 17:02:42,046 - root - DEBUG - @tcm: In _optim_state_dict(): optim_state_key_to_param_key: {_OptimStateKey(unflat_param_names=('model.mm_projector.0.weight',), is_fsdp_managed=True): 0, _OptimStateKey(unflat_param_names=('model.mm_projector.2.weight',), is_fsdp_managed=True): 1, _OptimStateKey(unflat_param_names=('model.mm_projector.0.bias',), is_fsdp_managed=True): 3, _OptimStateKey(unflat_param_names=('model.mm_projector.2.bias',), is_fsdp_managed=True): 4} # rank 2 or 3

To understand why optim_state_key_to_param_key is empty, I looked into the function that builds it: pytorch/torch/distributed/fsdp/_optim_utils.py at v2.1.2 · pytorch/pytorch · GitHub
Looking at the for loop at its beginning:

for param_key, param in param_key_to_param.items():
    # Do not include parameters without state to avoid empty mappings
    # just like in normal `torch.optim.Optimizer.state_dict()`
    if param_key not in optim_state_dict["state"]:
        continue

optim_state_dict["state"] is empty so the iteration is skipped, causing optim_state_key_to_param_key to not be updated.

2025-02-22 17:02:42,041 - root - DEBUG - @tcm: In _map_param_key_to_optim_keys(): optim_state_dict["state"]: {} # in empty ranks such as 0
2025-02-22 17:02:28,436 - root - DEBUG - @tcm: In _map_param_key_to_optim_keys(): optim_state_dict["state"]: {0: {'step': tensor(1.), 'exp_avg': tensor([-6.2440e-07,  6.8065e-07, -2.2726e-06,  ...,  3.1220e-07,          1.6088e-06, -1.5047e-07], device='cuda:1'), 'exp_avg_sq': tensor([3.8987e-14, 4.6328e-14, 5.1646e-13,  ..., 9.7467e-15, 2.5882e-13, 2.2642e-15],  device='cuda:1')}...

So the problem comes from optim_state_dict being empty when it is passed into _optim_state_dict(); in fact, it is None when FSDP.optim_state_dict() is first called:

2025-02-22 17:02:28,304 - root - DEBUG - @tcm: In FSDP.optim_state_dict(): optim_state_dict: None

and it is then initialized inside FSDP.optim_state_dict() through:

if optim_state_dict is None:
    optim_state_dict = optim.state_dict()

(the whole chain is kicked off from the Trainer via accelerate's save_fsdp_optimizer(self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir)).

So the optim_state_dict that torch.distributed.fsdp._optim_utils.py ends up working with is produced by torch.optim.Optimizer.state_dict(): pytorch/torch/optim/optimizer.py at main · pytorch/pytorch · GitHub
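This matches plain torch.optim behavior outside FSDP. As a standalone sketch of my own (no FSDP involved), the "state" part of Optimizer.state_dict() is empty until the optimizer has actually stepped a parameter that had a gradient:

# Standalone illustration: Optimizer.state_dict()["state"] starts empty and
# only gains entries for params that have been stepped with a gradient.
import torch

a = torch.nn.Parameter(torch.randn(4))
b = torch.nn.Parameter(torch.randn(4))
opt = torch.optim.AdamW([a, b], lr=1e-3)

print(opt.state_dict()["state"])        # {} -- nothing before the first step

a.grad = torch.randn(4)                 # only `a` receives a gradient
opt.step()
print(list(opt.state_dict()["state"]))  # [0] -- `b` (index 1) has no entry

So if, on some rank, none of the optimizer's parameters ever got any state created, state_dict() hands FSDP an empty "state", which is exactly what the logs above show for ranks 0 and 1.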

For context, this is the architecture of the model I'm trying to fine-tune:

FullyShardedDataParallel(
  (_fsdp_wrapped_module): CambrianLlamaForCausalLM(
    (model): CambrianLlamaModel(
      (embed_tokens): Embedding(128256, 3072)
      (layers): ModuleList(
        (0-27): 28 x FullyShardedDataParallel(
          (_fsdp_wrapped_module): LlamaDecoderLayer(
            (self_attn): LlamaSdpaAttention(
              (q_proj): Linear(in_features=3072, out_features=3072, bias=False)
              (k_proj): Linear(in_features=3072, out_features=1024, bias=False)
              (v_proj): Linear(in_features=3072, out_features=1024, bias=False)
              (o_proj): Linear(in_features=3072, out_features=3072, bias=False)
              (rotary_emb): LlamaRotaryEmbedding()
            )
            (mlp): LlamaMLP(
              (gate_proj): Linear(in_features=3072, out_features=8192, bias=False)
              (up_proj): Linear(in_features=3072, out_features=8192, bias=False)
              (down_proj): Linear(in_features=8192, out_features=3072, bias=False)
              (act_fn): SiLU()
            )
            (input_layernorm): LlamaRMSNorm()
            (post_attention_layernorm): LlamaRMSNorm()
          )
        )
      )
      (norm): LlamaRMSNorm()
      (rotary_emb): LlamaRotaryEmbedding()
      (mm_projector): Sequential(
        (0): Linear(in_features=1024, out_features=3072, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=3072, out_features=3072, bias=True)
      )
      (mm_projector_aux_0): Sequential(
        (0): Linear(in_features=1152, out_features=1024, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=1024, out_features=1024, bias=True)
        (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
      (mm_projector_aux_1): Sequential(
        (0): Linear(in_features=1536, out_features=1024, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=1024, out_features=1024, bias=True)
        (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      )
      (vision_sampler_0): VisionTokenSampler(
        (layers): ModuleList(
          (0-2): 3 x VisionCrossAttentionLayer(
            (proj_context): Linear(in_features=1024, out_features=1024, bias=False)
            (proj_in): Linear(in_features=2048, out_features=1024, bias=False)
            (proj_out): MLP(
              (linear_1): Linear(in_features=1024, out_features=1024, bias=False)
              (act): GELU(approximate='none')
              (linear_2): Linear(in_features=1024, out_features=1024, bias=False)
            )
            (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (cross_attn): MultiKVCrossAttention(
              (q_proj): Sequential(
                (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                (1): Linear(in_features=1024, out_features=1024, bias=False)
              )
              (k_proj_0): Sequential(
                (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                (1): Linear(in_features=1024, out_features=1024, bias=False)
              )
              (v_proj_0): Sequential(
                (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                (1): Linear(in_features=1024, out_features=1024, bias=False)
              )
              (k_proj_1): Sequential(
                (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                (1): Linear(in_features=1024, out_features=1024, bias=False)
              )
              (v_proj_1): Sequential(
                (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
                (1): Linear(in_features=1024, out_features=1024, bias=False)
              )
              (o_proj): Linear(in_features=1024, out_features=1024, bias=False)
            )
          )
        )
      )
      (lm_head): Linear(in_features=3072, out_features=128256, bias=False)
    )
  )
)

There are two kinds of FSDP units: the root unit wrapping the entire model, and one unit per LlamaDecoderLayer. In my fine-tuning script, this is how I configure the FSDP options passed to Trainer:

--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
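As far as I understand, these flags roughly correspond to the following raw-FSDP setup (a sketch of my own, not the exact code path Trainer/accelerate takes; if the LlamaDecoderLayer here is the Cambrian codebase's own class rather than transformers', the import would differ, and the FSDP() call itself needs an initialized process group):

# Rough raw-FSDP equivalent of `--fsdp "full_shard auto_wrap"` with
# `--fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer'` -- my sketch only.
import functools

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer  # or the Cambrian variant

auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)

# model = FSDP(
#     model,
#     sharding_strategy=ShardingStrategy.FULL_SHARD,
#     auto_wrap_policy=auto_wrap_policy,
# )  # requires torch.distributed.init_process_group() to have been called

With auto-wrapping, everything that is not inside a decoder layer (lm_head, mm_projector, the vision samplers, embed_tokens, ...) ends up flat-sharded by the root FSDP unit, which, as far as I can tell, is the unit whose parameters show up as empty on some ranks below.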

Since the error is related to saving the FSDP optimizer state, here is the create_optimizer() method of the LLaVATrainer class, which subclasses the Hugging Face Trainer:

class LLaVATrainer(Trainer):
    def create_optimizer(self):
        """
        Setup the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        Trainer's init through `optimizers`, or subclass and override this method in a subclass.
        """
        # pyre-fixme[16]: `Trainer` has no attribute `model`.
        opt_model = self.model
        # if self.args.unfreeze_mm_vision_tower:
        #     opt_model.get_model().vision_tower_aux_list = nn.ModuleList(opt_model.get_vision_tower_aux_list())
        #     self.param_to_name = map_params_to_module_names([opt_model])
        # pyre-fixme[16]: `Trainer` has no attribute `optimizer`.
        if self.optimizer is None:
            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
            decay_parameters = [name for name in decay_parameters if "bias" not in name]
            # pyre-fixme[16]: `Trainer` has no attribute `mm_projector_lr`.
            assert not (self.args.mm_projector_lr and self.args.mm_vision_sampler_lr)
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p
                        for n, p in opt_model.named_parameters()
                        if (n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [
                        p
                        for n, p in opt_model.named_parameters()
                        if (n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
            ]

            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                self.args
            )
            self.optimizer = optimizer_cls(
                optimizer_grouped_parameters, **optimizer_kwargs
            )
        return self.optimizer
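For anyone who wants to reproduce the per-rank check below, this is the kind of hypothetical helper (names are mine) that can be called from create_optimizer() to see whether the current rank actually holds a shard of a given original parameter:

# Hypothetical debugging helper (names are mine): log, per rank, whether this
# rank holds a shard of a given original parameter. numel() == 0 means the
# parameter's shard lives entirely on other ranks.
import torch.distributed as dist

def log_param_ownership(model, suffixes=("lm_head.weight",)):
    rank = dist.get_rank() if dist.is_initialized() else 0
    for name, param in model.named_parameters():
        if name.endswith(suffixes):
            print(
                f"rank {rank}: {name}: numel={param.numel()}, "
                f"requires_grad={param.requires_grad}, dtype={param.dtype}"
            )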

In the model codebase, the create_optimizer() method above builds the parameter groups. When I print the named parameters inside create_optimizer(), the output is as follows:

2025-02-22 17:01:20,658 - root - DEBUG - @tcm: In LLaVATrainer: n = _fsdp_wrapped_module.lm_head.weight, p = Parameter containing:
tensor([], device='cuda:0', dtype=torch.bfloat16, requires_grad=True)
2025-02-22 17:01:20,695 - root - DEBUG - @tcm: In LLaVATrainer: n = _fsdp_wrapped_module.lm_head.weight, p = Parameter containing:
tensor([], device='cuda:1', dtype=torch.bfloat16, requires_grad=True)
...
2025-02-22 17:01:20,690 - root - DEBUG - @tcm: In LLaVATrainer: n = _fsdp_wrapped_module.lm_head.weight, p = Parameter containing:
tensor([ 0.0103,  0.0090,  0.0134,  ...,  0.0049, -0.0025, -0.0052],
       device='cuda:2', dtype=torch.bfloat16, requires_grad=True)
2025-02-22 17:01:20,691 - root - DEBUG - @tcm: In LLaVATrainer: n = _fsdp_wrapped_module.lm_head.weight, p = Parameter containing:
tensor([-0.0099, -0.0302, -0.0054,  ..., -0.0038, -0.0027, -0.0015],
       device='cuda:3', dtype=torch.bfloat16, requires_grad=True)

As can be seen, on ranks 0 and 1 the lm_head weight inside the FSDP root unit is an empty tensor, while it holds values on ranks 2 and 3. So I think the error might stem from this FSDP sharding, where the same lm_head parameter is sharded onto ranks 2 and 3 but is empty on ranks 0 and 1. Therefore, my question is:

Why does FSDP in Trainer shard a layer such that it is empty on certain ranks, possibly leading to the error above?

I've been trying my best to find the root cause and fix this error, but the codebase is large and complex, so I am seeking help from the community.

Thanks in advance.

P.S.: Since the error and the way I have approached it are quite technically deep, please allow me to ping @sgugger and perhaps a few other experts on this problem. Thank you!

Hi. It looks like this post isn't going to get any comments, so if somebody knows where else I could ask this and potentially get help, please tell me, I would appreciate it. Thanks!

There may well be a bug in the Transformers or Accelerate libraries, but it could also just be how PyTorch is specified to behave. In any case, raising the issue on the GitHub issue tracker of one of these libraries will certainly get it in front of the developers.

FSDP itself is handled by the Accelerate library (or by PyTorch itself), but the function that was executing when the error occurred comes from the Transformers library, so I think there would be no problem raising it on the Transformers GitHub either.

Thank you for your direction.
I'm not sure how many factors are contributing to the error above, but I've just found that a possible reason is self.state in torch.optim.Optimizer, as can be seen here: pytorch/torch/optim/optimizer.py at v2.1.2 · pytorch/pytorch · GitHub

self.state stores the optimizer's per-parameter state, and when I print it across ranks, there is a rank where self.state is empty:

2025-02-24 04:16:18,691 - root - DEBUG - @tcm: In optim.Optimizer.state_dict(): self.state=defaultdict(<class 'dict'>, {})

On the other ranks, self.state is populated:

2025-02-24 04:16:07,667 - root - DEBUG - @tcm: In optim.Optimizer.state_dict(): self.state=defaultdict(<class 'dict'>, {Parameter containing:
tensor([-0.0099, -0.0302, -0.0054,  ..., -0.0038, -0.0027, -0.0015],
       device='cuda:3', requires_grad=True): {'step': tensor(1.), 'exp_avg': tensor([-2.1631e-11, -6.3862e-11, -3.7596e-11,  ..., -1.6094e-12, 7.6394e-12, -6.6093e-12], device='cuda:3'), 'exp_avg_sq': tensor([4.6788e-23, 4.0783e-22, 1.4135e-22,  ..., 2.5902e-25, 5.8360e-24, 4.3683e-24], device='cuda:3')}})
2025-02-24 04:16:07,667 - root - DEBUG - @tcm: In optim.Optimizer.state_dict(): self.state=defaultdict(<class 'dict'>, {Parameter containing:
tensor([ 0.0103,  0.0090,  0.0134,  ...,  0.0049, -0.0025, -0.0052],
       device='cuda:2', requires_grad=True): {'step': tensor(1.), 'exp_avg': tensor([-1.9513e-08,  5.0979e-09,  1.5206e-08,  ..., -2.0051e-10, 9.6136e-11, -1.6412e-10], device='cuda:2'), 'exp_avg_sq': tensor([3.8075e-17, 2.5989e-18, 2.3122e-17,  ..., 4.0205e-21, 9.2421e-22, 2.6935e-21], device='cuda:2')}})

Since the optimizer I use is AdamW, looking into adamw.py I see that for the state to be initialized (i.e., non-empty), the optimizer has to go through its first step() call: pytorch/torch/optim/adamw.py at v2.1.2 · pytorch/pytorch · GitHub
where self._init_group performs the initialization. However, I'm still quite unsure about the mechanism behind initializing and updating the optimizer state.
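For what it's worth, that lazy initialization is easy to see in a standalone snippet (my own sketch, plain torch.optim with no FSDP): state is only created for a parameter during step(), and only if that parameter actually has a gradient:

# Standalone check of AdamW's lazy state initialization.
import torch

a = torch.nn.Parameter(torch.randn(3))
b = torch.nn.Parameter(torch.randn(3))
opt = torch.optim.AdamW([a, b], lr=1e-3)

print(dict(opt.state))                 # {} -- no state has been created yet

a.grad = torch.randn(3)                # `b` never gets a gradient
opt.step()

print(a in opt.state, b in opt.state)  # True False
print(sorted(opt.state[a].keys()))     # ['exp_avg', 'exp_avg_sq', 'step']

So a rank whose parameters never receive gradients (or whose shards are empty) would seem to end up with an empty self.state, which is consistent with the empty optim_state_dict["state"] observed on ranks 0 and 1 above.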