Why are embedding / pooler layers excluded from pruning comparisons?

Hi @VictorSanh,

In your Saving PruneBERT notebook, I noticed that you only save the encoder and head when comparing the effects of pruning / quantisation. For example, here you save the original dense model as follows:

    import os

    import torch

    # Saving the original (encoder + classifier) in the standard torch.save format
    # (`model` is the dense model loaded earlier in the notebook)
    dense_st = {name: param for name, param in model.state_dict().items()
                if "embedding" not in name and "pooler" not in name}
    torch.save(dense_st, "dbg/dense_squad.pt")
    dense_mb_size = os.path.getsize("dbg/dense_squad.pt")  # note: getsize returns bytes

My question is: why are the embedding and pooler layers excluded from the size comparison between the BERT-base model and its pruned / quantised counterpart?

Naively, I would have thought that if I care about the amount of storage my model requires, then I would include all layers in the size calculation.


The QA model actually only needs the QA head; the pooler is just decorative (it's not even trained). The start and end of each span are predicted directly from the sequence of hidden states, which is why I am not saving the pooler.
As for the embeddings, I'm only fine-pruning the encoder: the embedding modules stay fixed at their pre-trained values. So I am mostly interested in comparing the compression ratio of the encoder (since the rest is fixed).
Hope that makes sense.
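To put the encoder-only comparison in context, here is a back-of-the-envelope count of where the parameters live. It assumes the standard bert-base-uncased configuration (vocabulary 30522, hidden size 768, 12 layers, intermediate size 3072, 512 positions, 2 token types) plus a 2-output QA head; it is plain arithmetic, not read from a checkpoint.

```python
# Parameter counts for BERT-base, assuming the standard configuration.
# Biases and LayerNorm parameters are included.
VOCAB, HIDDEN, LAYERS, INTERMEDIATE, POSITIONS, TYPES = 30522, 768, 12, 3072, 512, 2


def linear(n_in: int, n_out: int) -> int:
    """Weight matrix plus bias of a dense layer."""
    return n_in * n_out + n_out


layer_norm = 2 * HIDDEN  # gamma and beta

# Word, position and token-type embeddings, followed by one LayerNorm
embeddings = (VOCAB + POSITIONS + TYPES) * HIDDEN + layer_norm

per_layer = (
    4 * linear(HIDDEN, HIDDEN)        # query, key, value, attention output
    + layer_norm                      # post-attention LayerNorm
    + linear(HIDDEN, INTERMEDIATE)    # feed-forward up-projection
    + linear(INTERMEDIATE, HIDDEN)    # feed-forward down-projection
    + layer_norm                      # post-feed-forward LayerNorm
)
encoder = LAYERS * per_layer
pooler = linear(HIDDEN, HIDDEN)
qa_head = linear(HIDDEN, 2)  # start / end logits

total = embeddings + encoder + pooler + qa_head
for name, n in [("embeddings", embeddings), ("encoder", encoder),
                ("pooler", pooler), ("qa head", qa_head)]:
    print(f"{name:>10}: {n:>11,d} params ({100 * n / total:.1f}%)")
```

Under these assumptions the encoder holds roughly 78% of the parameters and the embeddings roughly 22%, so an encoder-only ratio isolates the part that fine-pruning actually changes, while a whole-model size would also carry the fixed embedding cost.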


Thanks for the answer @VictorSanh - this makes perfect sense!

Hi @VictorSanh, I have a follow up question about the Saving PruneBERT notebook.

As far as I can tell, you rely on weight quantization in order to be able to use the CSR format on integer-valued weights - is this correct?

My question is whether it is possible to show the memory compression benefits of fine-pruning without quantizing the model first?

What I’d like to do is quantify the memory reduction of BERT-base vs your PruneBERT model, so that one can clearly see that X% comes from pruning, Y% from quantization and so on.
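One way to get such an X% / Y% split is to estimate the storage of a single weight matrix under each encoding. A rough sketch, assuming float32 dense weights, a CSR layout with int32 column indices and row pointers, and int8 quantized values (the notebook's exact layout may differ); the 10% density is illustrative, not a measured PruneBERT figure.

```python
# Rough on-disk size of one weight matrix under three encodings.
# Assumptions (hypothetical, not the notebook's exact layout):
#   - dense weights are float32 (4 bytes each)
#   - CSR stores one int32 column index per non-zero and (rows + 1) int32 row pointers
#   - the quantized variant stores non-zero values as int8 (1 byte each);
#     the few bytes for the quantization scale are ignored


def dense_fp32_bytes(rows: int, cols: int) -> int:
    return rows * cols * 4


def csr_bytes(rows: int, nnz: int, value_bytes: int, index_bytes: int = 4) -> int:
    return nnz * (value_bytes + index_bytes) + (rows + 1) * index_bytes


rows = cols = 768
nnz = 58_982  # about 10% of the 589,824 entries survive pruning (illustrative)

dense = dense_fp32_bytes(rows, cols)
pruned = csr_bytes(rows, nnz, value_bytes=4)
pruned_q = csr_bytes(rows, nnz, value_bytes=1)

print(f"dense fp32        : {dense:>9,d} bytes")
print(f"pruned (CSR, fp32): {pruned:>9,d} bytes  x{dense / pruned:.2f} from pruning")
print(f"pruned + int8     : {pruned_q:>9,d} bytes  x{pruned / pruned_q:.2f} more from quantization")
```

The two factors compose multiplicatively, so the overall reduction is their product; attributing it step by step (pruning first, then quantization) is one convention among several.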


The notebook you are playing with only applies weight quantization. It takes the fine-pruned (pruned during fine-tuning) model as input, so to see the impact of pruning (compression), simply count the number of non-zero values in the encoder. That should give you the compression rate of pruning!
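A minimal sketch of counting non-zeros, reusing the same name filter as the notebook (so the QA head is counted together with the encoder). The helper name `encoder_density` and the toy state_dict below are hypothetical; with the real model one would pass `model.state_dict()`.

```python
import torch


def encoder_density(state_dict, skip=("embedding", "pooler")):
    """Fraction of non-zero entries across all tensors whose name contains
    none of the substrings in `skip` (same filter as the notebook)."""
    total = nonzero = 0
    for name, tensor in state_dict.items():
        if any(s in name for s in skip):
            continue
        total += tensor.numel()
        nonzero += int(torch.count_nonzero(tensor))
    return nonzero / total


# Toy state_dict with BERT-like parameter names (values are made up)
toy = {
    "bert.embeddings.word_embeddings.weight": torch.randn(10, 4),
    "bert.encoder.layer.0.attention.self.query.weight": torch.tensor([[1.0, 0.0], [0.0, 0.0]]),
    "bert.pooler.dense.weight": torch.randn(4, 4),
    "qa_outputs.weight": torch.tensor([[0.5, 0.0]]),
}
density = encoder_density(toy)  # 2 non-zeros out of 6 counted entries
print(f"density {density:.3f}, pruned fraction {1 - density:.3f}")
```

This counts biases and LayerNorm parameters too, which are presumably left dense by fine-pruning; restricting the count to 2-D weight matrices (`tensor.dim() == 2`) is a reasonable alternative if only the pruned matrices matter.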

Thanks for the clarification!

Counting the number of non-zero values is a good idea to get the compression rate, but what I’d usually do to quantify the size on disk (e.g. in MB) is save the encoder’s state_dict and get the size as follows:

    from pathlib import Path

    import torch

    state_dict = {name: param for name, param in model.state_dict().items() if "embedding" not in name and "pooler" not in name}
    tmp_path = Path("model.pt")
    torch.save(state_dict, tmp_path)
    # Calculate size in megabytes
    size_mb = tmp_path.stat().st_size / (1024 * 1024)

Now, my understanding is that if I load a fine-pruned model as follows:

    from transformers import BertForQuestionAnswering
    model = BertForQuestionAnswering.from_pretrained("huggingface/prunebert-base-uncased-6-finepruned-w-distil-squad")

then the model's weights are stored as dense tensors (the pruned entries are just zeros), so I don't see any compression gains on disk when I save the state_dict. Is this correct?

If yes, then do you know if there’s a way to save the state_dict of a fine-pruned model to disk in a way that reflects the compression gains from a sparse encoder?
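A quick way to check the premise: `torch.save` does not compress tensor storage, so a dense float32 tensor costs 4 bytes per entry whether or not most entries are zero. The random 90% mask below is illustrative only, not PruneBERT's actual pruning pattern, and `saved_size` is a hypothetical helper.

```python
import io

import torch


def saved_size(obj) -> int:
    """Number of bytes torch.save produces for `obj` (written to memory, not disk)."""
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    return buffer.getbuffer().nbytes


torch.manual_seed(0)
weight = torch.randn(768, 768)

# Keep roughly 10% of the entries and zero out the rest
keep = torch.rand_like(weight) < 0.1
pruned = torch.where(keep, weight, torch.zeros_like(weight))

print("dense weight :", saved_size(weight), "bytes")
print("pruned weight:", saved_size(pruned), "bytes")
```

Both sizes come out essentially identical, which is consistent with seeing no on-disk gain from a fine-pruned checkpoint stored as dense tensors.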


Ooooh yeah, sorry for the confusion.
As far as I know (I think I tried), you can use the torch.sparse tensor representation, which decomposes a sparse tensor into its CSR format (the locations of the non-zero values plus the values themselves). It should give you a compression gain in MB.
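A minimal sketch of that suggestion, assuming a single 768 x 768 weight with an illustrative ~90% sparsity. `Tensor.to_sparse()` produces PyTorch's COO layout (an int64 index tensor plus a values tensor), and `torch.save` can serialize it directly; `saved_size` is a hypothetical helper.

```python
import io

import torch


def saved_size(obj) -> int:
    """Number of bytes torch.save produces for `obj` (written to memory)."""
    buffer = io.BytesIO()
    torch.save(obj, buffer)
    return buffer.getbuffer().nbytes


torch.manual_seed(0)
weight = torch.randn(768, 768)
keep = torch.rand_like(weight) < 0.1  # illustrative ~90% sparsity
pruned = torch.where(keep, weight, torch.zeros_like(weight))

# COO layout: an int64 index tensor of shape (2, nnz) plus a values tensor of shape (nnz,)
sparse = pruned.to_sparse().coalesce()
print(sparse.layout, "non-zeros:", sparse.values().numel())
print("dense  :", saved_size(pruned), "bytes")
print("sparse :", saved_size(sparse), "bytes")

# After loading, convert back before using the weight in a regular nn.Linear
restored = sparse.to_dense()
print("lossless round trip:", torch.equal(restored, pruned))
```

For a whole encoder, one untested option is to apply `.to_sparse()` to each 2-D weight of the filtered state_dict before saving and `.to_dense()` after loading.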
The reason why I encoded the CSR format "by hand" is that sparse quantized tensors don't exist yet in PyTorch, so I had to implement the quantization and the CSR encoding on top myself.
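A rough sketch of what a hand-rolled "CSR + int8" encoding can look like: row pointers, column indices, and non-zero values quantized to int8. The quantization scheme (symmetric, one float scale per tensor) and the function names `quantized_csr` / `dequantize_csr` are hypothetical; the notebook's actual implementation may differ.

```python
import torch


def quantized_csr(weight: torch.Tensor):
    """Encode a 2-D float tensor as CSR arrays with int8 values.

    Hypothetical scheme: symmetric per-tensor quantization with one float scale.
    """
    mask = weight != 0
    values = weight[mask]                    # non-zeros in row-major order
    col_indices = mask.nonzero()[:, 1]       # column of each non-zero (same order)
    counts = mask.sum(dim=1)                 # non-zeros per row
    row_ptr = torch.cat([torch.zeros(1, dtype=torch.long), counts.cumsum(0)])
    max_abs = values.abs().max() if values.numel() else torch.tensor(0.0)
    scale = (max_abs / 127).item() or 1.0    # avoid dividing by zero for an all-zero tensor
    q_values = torch.round(values / scale).to(torch.int8)
    return row_ptr, col_indices, q_values, scale, weight.shape


def dequantize_csr(row_ptr, col_indices, q_values, scale, shape):
    """Rebuild the dense float tensor from the quantized CSR arrays."""
    counts = row_ptr[1:] - row_ptr[:-1]
    rows = torch.repeat_interleave(torch.arange(shape[0]), counts)
    dense = torch.zeros(shape)
    dense[rows, col_indices] = q_values.float() * scale
    return dense


torch.manual_seed(0)
w = torch.randn(768, 768)
w = torch.where(torch.rand_like(w) < 0.1, w, torch.zeros_like(w))  # illustrative sparsity
row_ptr, col_indices, q_values, scale, shape = quantized_csr(w)
restored = dequantize_csr(row_ptr, col_indices, q_values, scale, shape)
print("non-zeros:", q_values.numel(),
      "max abs error:", (restored - w).abs().max().item(), "scale/2:", scale / 2)
```

Zeros are reconstructed exactly and every non-zero is off by at most half a quantization step. The column indices here are int64 (PyTorch's default index dtype); downcasting them to int16 would be a natural refinement for 768- or 3072-column matrices.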

Thanks for the tip about torch.sparse: from the docs it seems to use the COO format, which should also work well :grinning_face_with_smiling_eyes:
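Whether COO "works well" depends on the density. Since PyTorch's COO layout stores one int64 index per dimension for every non-zero, a 2-D float32 tensor pays 20 bytes per non-zero versus 4 bytes per entry when dense. A small arithmetic sketch (the helper names are hypothetical):

```python
def dense_bytes(numel: int, value_bytes: int = 4) -> int:
    """Dense storage: every entry is stored, zero or not."""
    return numel * value_bytes


def coo_bytes(nnz: int, value_bytes: int = 4, ndim: int = 2, index_bytes: int = 8) -> int:
    """COO storage: each non-zero carries one int64 index per dimension plus its value."""
    return nnz * (value_bytes + ndim * index_bytes)


def break_even_density(value_bytes: int = 4, ndim: int = 2, index_bytes: int = 8) -> float:
    """Density below which COO takes fewer bytes than the dense tensor."""
    return value_bytes / (value_bytes + ndim * index_bytes)


numel = 768 * 768
for density in (0.05, 0.10, 0.20, 0.30):
    nnz = round(density * numel)
    print(f"density {density:.0%}: dense {dense_bytes(numel):,} B, COO {coo_bytes(nnz):,} B")
print(f"float32 COO is smaller than dense below {break_even_density():.0%} density")
```

With int8 values (`value_bytes=1`) the break-even drops to 1/17, about 6%, so once values are quantized, compact index dtypes matter much more than with float32 values.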

And thanks for clarifying the reason for encoding the CSR format by hand - when I find a solution to the torch > 1.5 issue, I’ll expand the text accordingly!