Pegasus Model Weights Compression/Pruning

hey @SteveMama, i had a closer look at from_pretrained and indeed it does not support loading quantized models because quantization changes the model’s state_dict (for example by introducing scale and zero_point parameters).

however, i think there is a workaround that involves the following steps:

1. quantize your model

import torch
from transformers import PegasusTokenizer, PegasusForConditionalGeneration

# load fine-tuned model and tokenizer
model_ckpt = "google/pegasus-cnn_dailymail"
tokenizer = PegasusTokenizer.from_pretrained(model_ckpt)
model = PegasusForConditionalGeneration.from_pretrained(model_ckpt)
# quantize model
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)
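
if you want to see how much smaller the quantized model is, one optional sanity check (not part of the original steps, just a rough sketch) is to save both state_dicts to disk and compare file sizes:

import os

torch.save(model.state_dict(), "pegasus-fp32.pt")
torch.save(quantized_model.state_dict(), "pegasus-int8.pt")
# dynamic int8 quantization of the linear layers should shrink the file noticeably
print(f"fp32: {os.path.getsize('pegasus-fp32.pt') / 1e6:.0f} MB")
print(f"int8: {os.path.getsize('pegasus-int8.pt') / 1e6:.0f} MB")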

2. save the quantized model’s state_dict and config to disk

# save config
quantized_model.config.save_pretrained("pegasus-quantized-config")
# save state dict
quantized_state_dict = quantized_model.state_dict()
torch.save(quantized_state_dict, "pegasus-quantized.pt")
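
(optional) if your heroku app also needs the tokenizer, you could save it next to the config so the app does not have to download it from the hub, something like:

# save tokenizer alongside the config (only needed if the app can't reach the hub)
tokenizer.save_pretrained("pegasus-quantized-config")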

3. in your heroku app, create a dummy model using the saved config

import torch
from transformers import AutoConfig, PegasusForConditionalGeneration

# load config and create an unquantized dummy model from it
config = AutoConfig.from_pretrained("pegasus-quantized-config")
dummy_model = PegasusForConditionalGeneration(config)

4. quantize dummy model and load state dict

# apply the same dynamic quantization so the state_dict keys match
reconstructed_quantized_model = torch.quantization.quantize_dynamic(
    dummy_model, {torch.nn.Linear}, dtype=torch.qint8
)
# load the quantized weights saved in step 2
reconstructed_quantized_model.load_state_dict(torch.load("pegasus-quantized.pt"))

from here you should be able to run reconstructed_quantized_model.generate and produce coherent outputs - let me know if you cannot.
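
for reference, here is a rough sketch of what that could look like (the input text is just a placeholder, and i am loading the tokenizer from the hub checkpoint here, adjust as needed):

from transformers import PegasusTokenizer

tokenizer = PegasusTokenizer.from_pretrained("google/pegasus-cnn_dailymail")

text = "replace this with the article you want to summarize ..."
inputs = tokenizer(text, truncation=True, return_tensors="pt")
summary_ids = reconstructed_quantized_model.generate(**inputs)
print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))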

an alternative is to use torchscript for the serialization, as done in this pytorch tutorial (link), but i am not very familiar with torchscript and it seems somewhat complicated because you need to trace the forward pass with some dummy inputs (this seems to work for BERT, but i am less sure about seq2seq models like pegasus).

hope that helps!
