Accelerate FSDP config prompts

Hello. I’m using Accelerate to leverage FSDP as per the tutorial here. I would like to fine-tune CodeBERT using run_mlm_no_trainer.py on a custom dataset. I’ve run accelerate config, but I’m not sure what the transformer_layer_cls_to_wrap prompt expects, either for my model or for the examples provided.

cc @smangrul

Hello @ablam, the blog post is outdated because the FSDP features were upgraded in PyTorch 1.12.0, and all of these new features have been integrated into HF Accelerate. For transformer-based models, the PyTorch team suggests using the transformer_auto_wrap_policy. With this policy, the user has to specify the case-sensitive class name of the encoder/decoder block, i.e. the module comprising the multi-head attention layer followed by the feed-forward layer. For example, in the T5 model, T5Block is the class of the block that is repeated N times in the encoder and decoder. Similarly, for BERT it is BertLayer, and for GPT-2 it is GPT2Block. Below is an example accelerate config for the bert-base-cased model:

compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: BertLayer
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
use_cpu: false
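For readers curious what TRANSFORMER_BASED_WRAP amounts to in plain PyTorch, here is a minimal sketch, assuming torch >= 1.12. PyTorch's transformer_auto_wrap_policy is bound to a set of block classes with functools.partial, and FSDP then wraps each instance of those classes as its own shard unit. ToyBlock and ToyEncoder are hypothetical stand-ins for BertLayer and BertEncoder, not real library classes.

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a block like BertLayer: attention followed by a feed-forward layer."""

    def __init__(self, dim: int = 16, heads: int = 2):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)  # self-attention over the sequence
        x = x + attn_out                  # residual connection
        return x + self.ffn(x)            # feed-forward with residual


class ToyEncoder(nn.Module):
    """Hypothetical stand-in for BertEncoder: N identical blocks in an nn.ModuleList."""

    def __init__(self, num_layers: int = 4):
        super().__init__()
        self.layer = nn.ModuleList(ToyBlock() for _ in range(num_layers))

    def forward(self, x):
        for block in self.layer:
            x = block(x)
        return x


# Bind the policy to the block class; this plays the role of fsdp_transformer_layer_cls_to_wrap.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={ToyBlock}
)
```

The resulting policy would be passed as FullyShardedDataParallel(model, auto_wrap_policy=auto_wrap_policy) inside an initialized process group. Accelerate builds an equivalent policy from the class name in the config, so you normally don't write this yourself.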

You will have to print the model, find the class name of this repeated block, and pass it as the value for fsdp_transformer_layer_cls_to_wrap. I hope this helps.
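Printing a large model produces a long module tree, so a small helper can narrow the search. This is a heuristic sketch, assuming the repeated blocks sit inside an nn.ModuleList (as they do in BERT, GPT-2 and T5). find_repeated_block_classes is a hypothetical helper, not part of Accelerate or Transformers.

```python
from collections import Counter

import torch.nn as nn


def find_repeated_block_classes(model: nn.Module) -> Counter:
    """Count class names of modules stored directly inside any nn.ModuleList.

    Hypothetical helper: the most frequent name is usually the transformer block
    class to use for fsdp_transformer_layer_cls_to_wrap, but always confirm it
    against print(model).
    """
    counts = Counter()
    for _, module in model.named_modules():
        if isinstance(module, nn.ModuleList):
            for child in module:
                counts[type(child).__name__] += 1
    return counts


# Usage with a real checkpoint (requires `transformers` and a model download):
# from transformers import AutoModelForMaskedLM
# model = AutoModelForMaskedLM.from_pretrained("microsoft/codebert-base")
# print(model)                               # the full module tree
# print(find_repeated_block_classes(model))  # candidate block class names with counts
```

Since CodeBERT uses the RoBERTa architecture, one would expect the helper to point at RobertaLayer, but the printed model remains the authoritative check.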

Yes, it does. Found the corresponding reference in my model. Thank you!

Hi,

I have a question about fsdp_transformer_layer_cls_to_wrap: how can I find the transformer layer class for different models? If I want to accelerate LLaMA, can I directly set this parameter to LlamaLayer?

Hello @Colorful, the docs have recently been updated to answer your question: Fully Sharded Data Parallel (huggingface.co). Specifically, see the highlighted part below.