FSDP Auto Wrap does not work using `accelerate` in a Multi-GPU Setup

Introduction

Hi,
I’m trying to fine-tune Llama3-8B with the accelerate package using FSDP and flash-attention. I followed the official guide to configure accelerate for FSDP (Fully Sharded Data Parallel). However, when I rely on the model's _no_split_modules attribute (i.e., leave fsdp_transformer_layer_cls_to_wrap undefined), the following error appears: "ValueError: Could not find the transformer layer class L in the model.". On the other hand, if I set fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer in the FSDP configuration file, the model printout suggests that the layers are not wrapped by FSDP.
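
My training script boils down to roughly the following (a simplified sketch; the model id, optimizer settings, and the omitted dataset/training loop are placeholders for the real script):

import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM

def main():
    accelerator = Accelerator()
    # Placeholder model id; the real script points at a local Llama3-8B checkpoint
    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Meta-Llama-3-8B",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
    )
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    # FSDP auto wrap is supposed to happen inside prepare(), driven by the
    # accelerate config shown below
    model, optimizer = accelerator.prepare(model, optimizer)
    accelerator.print(model)  # produces the printout shown further down

if __name__ == "__main__":
    main()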

Software Info

OS: Linux
Python: 3.10
CUDA: 12.4

Packages:

torch==2.4.1
transformers==4.44.2
datasets==2.21.0
accelerate==0.34.0
flash-attn==2.6.3

Accelerate config setup:

compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: [ ]
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
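
For context, my understanding is that with TRANSFORMER_BASED_WRAP, accelerate builds PyTorch's transformer_auto_wrap_policy from the configured class name, roughly like this (a sketch of the mechanism, not accelerate's exact code):

import functools
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# fsdp_transformer_layer_cls_to_wrap (or, if unset, model._no_split_modules)
# supplies the class name(s); each matching module should then become its own
# FSDP unit when the model is wrapped
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)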

Result of running accelerate launch
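
I launch the script with accelerate launch --config_file fsdp_config.yaml train.py (the file names are placeholders matching the sketch above).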

The print(model) output shows that FSDP auto wrap was not applied and the model layers are not sharded across the GPUs:

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaFlashAttention2(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)
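
For completeness, here is the check I would expect to pass after accelerator.prepare() (a minimal sketch): with per-layer wrapping, each LlamaDecoderLayer should sit inside its own FullyShardedDataParallel unit, but none appear in the printout above.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Count FSDP units after accelerator.prepare(); with TRANSFORMER_BASED_WRAP
# there should be one unit per decoder layer plus the root unit
num_fsdp_units = sum(isinstance(m, FSDP) for m in model.modules())
print(f"FSDP-wrapped modules: {num_fsdp_units}")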