Get number of parameters for different parts of a model

Hey there,

I know I can get the number of parameters in a PyTorch model with sum(p.numel() for p in model.parameters()) (adding "if p.requires_grad" restricts it to the trainable ones), but how can I get the count for the different parts of the model? For example, for BertForMaskedLM I tried the same code with model.base_model.parameters() and model.cls.parameters(), but the sum of the two results is well above the result of simply using model.parameters().
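The mismatch can be reproduced without downloading BERT. Below is a minimal sketch that assumes a hypothetical toy model, TinyMLM, which is not part of PyTorch or transformers. Its decoder weight is tied to the embedding matrix, similar to how BertForMaskedLM ties its MLM decoder to the word embeddings by default. It counts each top-level part separately and compares the sum with the model total:

```python
import torch.nn as nn

class TinyMLM(nn.Module):
    """Hypothetical toy masked-LM: embedding body plus a decoder head tied to it."""
    def __init__(self, vocab_size=100, hidden=16):
        super().__init__()
        self.base = nn.Sequential(nn.Embedding(vocab_size, hidden), nn.Linear(hidden, hidden))
        self.cls = nn.Linear(hidden, vocab_size)
        # Tie the decoder weight to the embedding matrix: both attributes now
        # point to the same Parameter object (both are vocab_size x hidden).
        self.cls.weight = self.base[0].weight

def count_params(module: nn.Module) -> int:
    # parameters() yields each distinct Parameter of this module once
    return sum(p.numel() for p in module.parameters())

model = TinyMLM()
parts = {name: count_params(child) for name, child in model.named_children()}
print(parts)                                     # {'base': 1872, 'cls': 1700}
print(sum(parts.values()), count_params(model))  # 3572 1972
```

The per-part counts add up to 1600 more than the total, which is exactly the size of the tied 100 x 16 matrix.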

I am sure I must be missing something very obvious here, but I don't know what.

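EDIT: Ah, I think I figured it out: the weights shared (tied) between the embedding layer and the decoder show up in both model.base_model.parameters() and model.cls.parameters(), but model.parameters() yields each shared parameter only once. So they are counted once in the model total, since effectively only one copy of them is trained?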
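If you want per-part counts that add up to the model total, one option is to credit each Parameter object only to the first top-level submodule it is found in. This is a minimal sketch on a hypothetical toy model with tied embedding/decoder weights. The helper count_parts_unique is my own illustration, not a PyTorch or transformers API:

```python
import torch.nn as nn

class TinyMLM(nn.Module):
    """Hypothetical toy masked-LM: embedding body plus a decoder head tied to it."""
    def __init__(self, vocab_size=100, hidden=16):
        super().__init__()
        self.base = nn.Sequential(nn.Embedding(vocab_size, hidden), nn.Linear(hidden, hidden))
        self.cls = nn.Linear(hidden, vocab_size)
        self.cls.weight = self.base[0].weight  # tie: same Parameter object

def count_parts_unique(model: nn.Module) -> dict:
    """Count parameters per top-level child, crediting each shared Parameter
    only to the first child (in registration order) that contains it."""
    seen = set()
    counts = {}
    for name, child in model.named_children():
        n = 0
        for p in child.parameters():
            if id(p) not in seen:  # skip Parameters already credited to an earlier child
                seen.add(id(p))
                n += p.numel()
        counts[name] = n
    return counts

model = TinyMLM()
counts = count_parts_unique(model)
total = sum(p.numel() for p in model.parameters())
print(counts)                       # {'base': 1872, 'cls': 100}
print(sum(counts.values()), total)  # 1972 1972
```

Here the tied matrix is credited to base, and cls keeps only its bias. On BertForMaskedLM the top-level children should be bert and cls, so the tied decoder weight would be credited to bert. Parameters registered directly on the top-level module, outside any child, would not be counted by this helper.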

Best
Johannes