Hi @lewtun, thanks for the question!
Indeed, all the linear layers (torch.nn.Linear) are replaced with custom modules that add score matrices, which accumulate the momentum used for pruning.
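For readers who want to picture such a module, here is a minimal PyTorch sketch, assuming a movement-pruning-style setup. `MaskedLinear`, `TopKMask` and `keep_ratio` are hypothetical names for illustration, not the classes used in the actual example code. A score matrix with the same shape as the weight is trained alongside it, and a top-k mask with a straight-through gradient keeps only the highest-scoring weights in the forward pass.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TopKMask(torch.autograd.Function):
    """Binary mask keeping the `keep_ratio` fraction of highest scores.

    The backward pass is a straight-through estimator: the gradient that
    reaches the mask is handed to the scores unchanged.
    """

    @staticmethod
    def forward(ctx, scores, keep_ratio):
        k = max(1, int(keep_ratio * scores.numel()))
        _, idx = torch.topk(scores.flatten(), k)
        mask = torch.zeros(scores.numel(), dtype=scores.dtype, device=scores.device)
        mask[idx] = 1.0  # exactly k entries are kept
        return mask.view_as(scores)

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through for the scores; no gradient for keep_ratio (a float).
        return grad_output, None


class MaskedLinear(nn.Linear):
    """Hypothetical nn.Linear plus a learnable score per weight entry."""

    def __init__(self, in_features, out_features, bias=True, keep_ratio=0.5):
        super().__init__(in_features, out_features, bias=bias)
        self.keep_ratio = keep_ratio
        self.scores = nn.Parameter(torch.zeros_like(self.weight))

    def forward(self, x):
        mask = TopKMask.apply(self.scores, self.keep_ratio)
        # dL/dscores = (dL/dW') * W, so gradient descent on the scores
        # accumulates -W * dL/dW' over steps (a movement-style signal).
        return F.linear(x, self.weight * mask, self.bias)
```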
As of now, we have no plan to include it more broadly in the transformers library, even though it is fairly straightforward to do: replace all the torch.nn.Linear layers and change the forward call. I believe @madlag has some code to do that automatically on the fly; maybe he would be open to sharing it?
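A rough sketch of the "replace every Linear and change the forward" approach, assuming a plain PyTorch model. `replace_linear_layers` and `ScoredLinear` are hypothetical names, not @madlag's code and not a transformers API. The helper walks the module tree and swaps each nn.Linear for a replacement built from it, copying the pretrained weights. The masking rule in `ScoredLinear` is a simple sign threshold with a straight-through gradient, chosen only to keep the example short.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScoredLinear(nn.Linear):
    """Hypothetical drop-in for nn.Linear with one learnable score per weight.

    Weights whose score is <= 0 are masked out. The term
    (scores - scores.detach()) is zero in value but lets gradients
    flow straight through to the scores.
    """

    def __init__(self, in_features, out_features, bias=True):
        super().__init__(in_features, out_features, bias=bias)
        # Positive init: every weight is kept at the start.
        self.scores = nn.Parameter(torch.ones_like(self.weight))

    @classmethod
    def from_linear(cls, lin: nn.Linear) -> "ScoredLinear":
        new = cls(lin.in_features, lin.out_features, bias=lin.bias is not None)
        new = new.to(device=lin.weight.device, dtype=lin.weight.dtype)
        # Copy weight/bias; `scores` is missing from lin's state dict, hence strict=False.
        new.load_state_dict(lin.state_dict(), strict=False)
        return new

    def forward(self, x):
        hard = (self.scores > 0).to(self.weight.dtype)
        mask = hard + (self.scores - self.scores.detach())
        return F.linear(x, self.weight * mask, self.bias)


def replace_linear_layers(module: nn.Module, make_replacement) -> nn.Module:
    """Recursively swap every nn.Linear child of `module` in place."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, make_replacement(child))
        else:
            replace_linear_layers(child, make_replacement)
    return module
```

Usage would look like `replace_linear_layers(model, ScoredLinear.from_linear)`. Because the scores start positive, the swapped model initially computes the same outputs as the original; pruning only takes effect once training pushes some scores to zero or below.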
Victor