Distributed fine-tuning with frozen embedding layers

Hi HF community,

I want to fine-tune a HF checkpoint while freezing the transformer's embedding layers.
With the setup described below, I get the following error when running under DDP with torchrun:
AssertionError: expects all parameters to have same requires_grad

The same script works fine when no weights are frozen.
It also works when weights are frozen but training runs on a single GPU (no DDP).
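To make the failing state concrete, here is a minimal pure-PyTorch sketch of what my freezing code produces: a model whose parameters carry mixed `requires_grad` flags. As I understand the assertion message, this mixed state is exactly what sharded DDP rejects. (`ToyModel` is a made-up stand-in for illustration, not the real ALBERT checkpoint.)

```python
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Stands in for model.albert.embeddings in the real script
        self.embeddings = nn.Embedding(10, 4)
        self.head = nn.Linear(4, 10)

model = ToyModel()

# Freeze the embeddings, as in my training script
for param in model.embeddings.parameters():
    param.requires_grad = False

# The model now holds a mix of frozen and trainable parameters
flags = {p.requires_grad for p in model.parameters()}
print(flags)  # both True and False are present
```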

This is how I instantiate the model before passing it to the Trainer:

```python
model = AlbertForMaskedLM.from_pretrained("Rostlab/prot_albert")
if freeze_embeddings:
    # Freeze the embedding layers only; the rest of the model stays trainable
    for param in model.albert.embeddings.parameters():
        param.requires_grad = False
```

Launch command with the main hyper-parameters:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --standalone --nnodes=1 --nproc_per_node=8 finetune_protbert.py --fp16 --fp16_opt_level O1 --sharded_ddp zero_dp_2
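One workaround I have been wondering about, but have not yet verified under `--sharded_ddp zero_dp_2`: leave every parameter's `requires_grad` as `True` (so the flags stay uniform) and instead exclude the embedding parameters from the optimizer, so they are never updated. A minimal sketch with a toy module (the toy model and the exclusion logic are my own assumptions, not code from the script):

```python
import torch
import torch.nn as nn

class ToyMLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(10, 4)
        self.head = nn.Linear(4, 10)

model = ToyMLM()

# All requires_grad flags stay True (uniform, so DDP sees no mismatch);
# the embeddings are effectively frozen by omitting them from the optimizer.
frozen_ids = {id(p) for p in model.embeddings.parameters()}
trainable = [p for p in model.parameters() if id(p) not in frozen_ids]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)

print(len(trainable))  # number of tensors the optimizer will update
```

If this is sound, I believe the Trainer's `optimizers=(optimizer, scheduler)` argument would let me pass such an optimizer in, though I have not tested that combination with sharded DDP. One downside: gradients are still computed and all-reduced for the excluded parameters, so this trades some memory and communication for uniform flags.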

Is this a known issue? Any hints for solving it would be much appreciated.

Thanks!