hey @ftian i had a chat with michael benayoun who ran into a similar issue while developing the quantization modules for the nn_pruning
library: https://github.com/huggingface/nn_pruning/tree/main/nn_pruning/modules
as general advice, he recommends the following:
For static quantization / QAT, things are a bit different; you need to:
- Load the model with the proper model config
- Apply the same quantization to the model as it was previously done
- Load the state dict from the checkpoint on that modified model (at this point, every scale and zero_point should be loaded correctly)
Because we are saving the state_dict and not the graph itself, it is impossible to “guess” where the observers / fake quantization / quantize nodes were located, so the second step is somewhat unavoidable (although I am working on graph mode quantization, which might solve that). For quantized models (after torch.quantization.convert), I would recommend tracing the model with torchscript, at least that’s what I have done, as it provides everything needed to run inference, which is usually the goal when a model was quantized.
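to make those steps concrete, here's a rough sketch of what the reload flow could look like in eager mode. the paths, the choice of quantizing only the nn.Linear layers, and the fbgemm backend are assumptions on my side; the one thing that matters is that the recipe matches whatever was applied before the checkpoint was saved:

```python
import torch
from transformers import AutoConfig, AutoModelForSequenceClassification

# hypothetical paths / task class -- adapt to your actual checkpoint
config = AutoConfig.from_pretrained("path/to/quantized-checkpoint")
model = AutoModelForSequenceClassification.from_config(config)
model.eval()

# step 2: re-apply exactly the same quantization recipe that produced the checkpoint.
# here we assume eager-mode static quantization of the nn.Linear layers with the
# fbgemm backend; use get_default_qat_qconfig / prepare_qat instead for a QAT run.
qconfig = torch.quantization.get_default_qconfig("fbgemm")
for module in model.modules():
    if isinstance(module, torch.nn.Linear):
        module.qconfig = qconfig
torch.quantization.prepare(model, inplace=True)   # inserts observers
torch.quantization.convert(model, inplace=True)   # swaps in quantized modules
# (skip the convert call if the checkpoint was saved before torch.quantization.convert,
#  e.g. mid-QAT with the fake-quant modules still in place)

# step 3: the module hierarchy now matches the saved one, so every scale and
# zero_point lands in the right place
state_dict = torch.load("path/to/quantized-checkpoint/pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
```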
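and for the torchscript route he mentions on a converted model, something along these lines (this continues from the sketch above; the dummy inputs are placeholders, you'd normally trace with real tokenizer outputs of a representative shape):

```python
import torch

# `model` / `config` come from the sketch above (already converted to quantized modules)
model.config.return_dict = False  # return plain tuples so torch.jit.trace can handle the outputs

dummy_input_ids = torch.randint(0, config.vocab_size, (1, 128))
dummy_attention_mask = torch.ones(1, 128, dtype=torch.long)

traced = torch.jit.trace(model, (dummy_input_ids, dummy_attention_mask))
torch.jit.save(traced, "quantized_model.pt")

# later, inference only needs the torchscript archive, not the modeling code or the recipe
loaded = torch.jit.load("quantized_model.pt")
with torch.no_grad():
    logits = loaded(dummy_input_ids, dummy_attention_mask)[0]
```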
of course this isn’t as simple as being able to load a quantized model with from_pretrained, so i’ll let @sgugger comment on whether this type of feature would make sense to include in transformers itself