Hi, I am trying out ZeRO-style parallelism for large-scale model training using Facebook's FairScale library rather than DeepSpeed as the ZeRO implementation. Specifically, I am hoping to apply ZeRO-3 to large transformer models. FairScale's FSDP (FullyShardedDataParallel) module lets users either (1) wrap the whole transformer model in FSDP directly for easier usage, or (2) wrap the model layer by layer for better parallelization and memory savings (example usage).
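For reference, the direct approach (1) that I have been using looks roughly like this (just a sketch, assuming fairscale is installed and a torch.distributed process group has already been initialized, since FSDP needs one):

```python
import torchvision
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Option (1): wrap the entire model in a single FSDP instance.
# All parameters are flattened and sharded together.
model = FSDP(torchvision.models.resnet50().cuda())
```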
I have tried both wrapping approaches on vision models like ResNet, since I can either import the model from torchvision and wrap it directly, or do per-layer wrapping by finding a coded implementation of the ResNet architecture and modifying it directly. For transformers, I have only tried directly wrapping models imported from Hugging Face. My question is: for such transformers, is it possible to access the coded model architecture so that per-layer FSDP wrapping can be applied?
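For concreteness, here is roughly what I have in mind (a sketch only, not tested end-to-end; it assumes fairscale and transformers are installed, a distributed process group is initialized, and relies on BertModel specifically exposing its transformer blocks as the nn.ModuleList model.encoder.layer — other architectures name their submodules differently):

```python
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.wrap import enable_wrap, wrap
from transformers import BertConfig, BertModel


def wrap_bert_layers(model: BertModel) -> FSDP:
    # Hugging Face models are ordinary torch.nn.Modules, so their submodules
    # are accessible as attributes without modifying the library source.
    with enable_wrap(wrapper_cls=FSDP):
        for i, layer in enumerate(model.encoder.layer):
            # Option (2): replace each transformer block with an
            # FSDP-wrapped copy so its parameters are sharded per layer.
            model.encoder.layer[i] = wrap(layer)
        # Wrap the remainder (embeddings, pooler) at the top level.
        return wrap(model)


# BertConfig() builds a randomly initialized model, avoiding a download.
model = wrap_bert_layers(BertModel(BertConfig()))
```

I also noticed fairscale ships an auto_wrap utility with a size-based policy, which seems intended to automate exactly this kind of nested wrapping, so I would be happy to hear whether that is the recommended route instead of manual replacement.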
Thank you in advance!