Accelerate FSDP config prompts

Hello. I’m using accelerate to leverage FSDP as per the tutorial here. I would like to finetune CodeBERT on a custom dataset. I’ve run accelerate config, but I’m not sure what the prompt for transformer_layer_cls_to_wrap entails, either for my model or for the examples provided.

cc @smangrul

Hello @ablam, the blog post is outdated, as the FSDP features were upgraded in PyTorch 1.12.0 and these new features have since been integrated into HF Accelerate. For transformer-based models, the PyTorch team suggests using the transformer_auto_wrap policy. With this policy, the user has to specify the case-sensitive class name of the repeated encoder/decoder block, i.e. the block comprising the multi-head attention layer followed by the feed-forward layer. For example, in the T5 model, T5Block is the name of this block, and the encoder and decoder each stack N such blocks. Similarly, for the BERT model it is BertLayer, and for GPT-2 it is GPT2Block. Below is an example of the accelerate config for the bert-base-cased model:

compute_environment: LOCAL_MACHINE
deepspeed_config: {}
distributed_type: FSDP
downcast_bf16: 'no'
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_transformer_layer_cls_to_wrap: BertLayer
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
mixed_precision: fp16
num_machines: 1
num_processes: 2
use_cpu: false

You will have to print the model, check the class name of the repeated attention block, and pass it as the value for fsdp_transformer_layer_cls_to_wrap. I hope this helps.
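For instance, here is a minimal sketch of that inspection step. It builds a tiny, randomly initialized BERT from a config (so no weights are downloaded; the layer sizes are arbitrary and chosen only to keep it small) and lists the class names of all submodules:

```python
from transformers import BertConfig, BertModel

# Tiny, randomly initialized BERT purely for inspection;
# no pretrained weights are downloaded.
config = BertConfig(hidden_size=64, num_hidden_layers=2,
                    num_attention_heads=4, intermediate_size=128)
model = BertModel(config)

# Collect the class names of every submodule. The repeated encoder block
# (BertLayer here) is the value for fsdp_transformer_layer_cls_to_wrap.
layer_classes = sorted({type(m).__name__ for m in model.modules()})
print(layer_classes)
```

Equivalently, `print(model)` shows the full module tree, where the repeated block appears under `encoder.layer`.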


Yes, it does. Found the corresponding reference in my model. Thank you!



I have a question about fsdp_transformer_layer_cls_to_wrap: how can I find the transformer layer class for different models? If I want to accelerate LLaMA, can I directly set this parameter to LlamaLayer?

Hello @Colorful, the docs have been updated recently; see the Fully Sharded Data Parallel section, which answers your question. Specifically, the highlighted part below.
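As a quick check independent of the docs, the same print-the-model approach works for LLaMA; in transformers the repeated decoder block is named LlamaDecoderLayer, not LlamaLayer. A small sketch using a tiny randomly initialized config (sizes are arbitrary; nothing is downloaded):

```python
from transformers import LlamaConfig, LlamaModel

# Tiny, randomly initialized LLaMA purely for inspection;
# no pretrained weights are downloaded.
config = LlamaConfig(hidden_size=64, intermediate_size=128,
                     num_hidden_layers=2, num_attention_heads=4,
                     vocab_size=1000)
model = LlamaModel(config)

# The repeated decoder block shows up as LlamaDecoderLayer, which is
# the value to use for fsdp_transformer_layer_cls_to_wrap.
print(sorted({type(m).__name__ for m in model.modules()}))
```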