DeepSpeed ZeRO stage 3 does not work with diffusion models. Does anyone know how to fix this?

Details are in huggingface/transformers issue #22705, "using Deepspeed zero stage3 finetune sd2, dimension error occurs".
The problem appears to be in how the parameters are partitioned: by the time the CLIP text encoder's token embedding weight reaches F.embedding, it is no longer 2-D.

    train(args)
  File "/workdir/fengyu05/501587/033c0dc6fc985d5dee8904327c53b497/src/diffuser/train_txt2img.py", line 529, in train
    encoder_hidden_states = text_encoder(batch["input_ids"].to(accelerator.device))[0]
  File "/usr/local/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/conda/lib/python3.9/site-packages/transformers/models/clip/modeling_clip.py", line 816, in forward
    return self.text_model(
  File "/usr/local/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/conda/lib/python3.9/site-packages/transformers/models/clip/modeling_clip.py", line 712, in forward
    hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
  File "/usr/local/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/conda/lib/python3.9/site-packages/transformers/models/clip/modeling_clip.py", line 227, in forward
    inputs_embeds = self.token_embedding(input_ids)
  File "/usr/local/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/conda/lib/python3.9/site-packages/torch/nn/modules/sparse.py", line 160, in forward
    return F.embedding(
  File "/usr/local/conda/lib/python3.9/site-packages/torch/nn/functional.py", line 2210, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
RuntimeError: 'weight' must be 2-D
Steps: 0%|
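
One way to check the partitioning hypothesis is to list the parameters whose local shape differs from their full shape just before the failing call. This is only a rough diagnostic sketch, not a fix. The helper name find_partitioned_params is made up for illustration, and it assumes that ZeRO-3 tags each partitioned parameter with a ds_shape attribute (the full shape) while leaving an empty local placeholder behind.

```python
def find_partitioned_params(module):
    """Return (name, local_shape, full_shape) for parameters that look partitioned.

    Assumption: DeepSpeed ZeRO-3 stores the full shape of a partitioned
    parameter in `param.ds_shape` and shrinks the local tensor to a placeholder.
    Parameters without `ds_shape` are treated as ordinary, unpartitioned ones.
    """
    suspects = []
    for name, param in module.named_parameters():
        full_shape = getattr(param, "ds_shape", None)
        if full_shape is None:
            continue  # not managed by ZeRO-3, as far as this check can tell
        local_shape = tuple(param.shape)
        # A placeholder (for example shape (0,)) in place of a 2-D embedding
        # matrix would explain "'weight' must be 2-D".
        if local_shape != tuple(full_shape):
            suspects.append((name, local_shape, tuple(full_shape)))
    return suspects


# Hypothetical usage inside the training script, right before the failing line:
#   for name, local, full in find_partitioned_params(text_encoder):
#       print(f"{name}: local {local}, full {full}")
```

If the token embedding weight of text_encoder shows up in that list, the text encoder's weights are still partitioned when its forward pass runs. One thing to try in that case, untested here, is wrapping the call in deepspeed.zero.GatheredParameters so the full weights are available for that forward pass.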