Introduction
Hi,
I'm trying to fine-tune Llama3-8B with the accelerate package using FSDP (Fully Sharded Data Parallel) and flash-attention. I followed the official guide to configure accelerate for FSDP. However, when I rely on the model's _no_split_modules attribute (i.e. leave fsdp_transformer_layer_cls_to_wrap undefined), the run fails with "ValueError: Could not find the transformer layer class L in the model." On the other hand, if I set fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer in the FSDP configuration file, the model printout suggests that the layers are not wrapped by FSDP.
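For context, the model is loaded and prepared roughly like this (simplified sketch of my script; the checkpoint path is a placeholder and the data/optimizer setup is omitted):

import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM

# Simplified sketch; "meta-llama/Meta-Llama-3-8B" stands in for the actual checkpoint path.
accelerator = Accelerator()

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # flash-attn 2.6.3 is installed
)

# accelerator.prepare() is where the FSDP auto-wrapping should happen
model = accelerator.prepare(model)
accelerator.print(model)  # produces the printout shown at the end of this post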
Software Info
OS: Linux
Python: 3.10
CUDA: 12.4
Packages:
torch==2.4.1
transformers==4.44.2
datasets==2.21.0
accelerate==0.34.0
flash-attn==2.6.3
Accelerate config setup:
compute_environment: LOCAL_MACHINE
debug: false
distributed_type: FSDP
downcast_bf16: 'no'
enable_cpu_affinity: false
fsdp_config:
  fsdp_activation_checkpointing: false
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch: BACKWARD_PRE
  fsdp_cpu_ram_efficient_loading: true
  fsdp_forward_prefetch: false
  fsdp_offload_params: false
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sync_module_states: true
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_use_orig_params: true
machine_rank: 0
main_training_function: main
mixed_precision: bf16
num_machines: 1
num_processes: 4
rdzv_backend: static
same_network: true
tpu_env: [ ]
tpu_use_cluster: false
tpu_use_sudo: false
use_cpu: false
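My understanding is that fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP together with fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer should resolve to roughly the following torch auto-wrap policy (illustrative sketch, not code from my script):

import functools

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Expected behaviour: every LlamaDecoderLayer gets wrapped in its own FSDP unit.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)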
Result of running accelerate launch
The output of print(model) suggests that the FSDP auto-wrap did not take effect: no module is wrapped in an FSDP unit and the model layers are not distributed across the GPUs.
LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaFlashAttention2(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_emb): LlamaRotaryEmbedding()
  )
  (lm_head): Linear(in_features=4096, out_features=128256, bias=False)
)
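As a sanity check, a snippet like the one below (hypothetical check run after accelerator.prepare(model), not part of my training script) should list the FSDP-wrapped submodules if the auto-wrap had taken effect:

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# With TRANSFORMER_BASED_WRAP working, each of the 32 LlamaDecoderLayer blocks
# should appear inside an FSDP unit here.
fsdp_units = [name for name, module in model.named_modules() if isinstance(module, FSDP)]
print(f"number of FSDP-wrapped modules: {len(fsdp_units)}")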