We use a RoBERTa base model for a binary classification problem, and I tried quantization with PyTorch, ONNX Runtime, and Optimum. With both the PyTorch and the ONNX Runtime quantized models, recall drops by around 5%, but with the Optimum quantized model recall drops by around 10%. For all of these models, precision is the same as the original model.
Code:
from optimum.onnxruntime import ORTModelForSequenceClassification, ORTQuantizer
from optimum.onnxruntime.configuration import AutoQuantizationConfig

# Export the fine-tuned model to ONNX and create a quantizer for it
ort_model = ORTModelForSequenceClassification.from_pretrained(model_path, export=True)
quantizer = ORTQuantizer.from_pretrained(ort_model)

# Dynamic (weight-only) quantization configuration targeting arm64
dqconfig = AutoQuantizationConfig.arm64(
    is_static=False,
    per_channel=True,
)

model_quantized_path = quantizer.quantize(
    save_dir=onnx_model_path,
    quantization_config=dqconfig,
)
I tried different parameters in the AutoQuantizationConfig.arm64 function, but the results are similar.
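For reference, the plain PyTorch dynamic quantization mentioned above typically looks like the sketch below (an illustrative sketch, not the exact script used here; model_path is the same fine-tuned checkpoint as above):

import torch
from transformers import AutoModelForSequenceClassification

# Load the fine-tuned RoBERTa classifier
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()

# Dynamically quantize the weights of all Linear layers to int8
quantized_model = torch.quantization.quantize_dynamic(
    model, {torch.nn.Linear}, dtype=torch.qint8
)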
By the way, how can I determine which nodes to exclude from quantization? Is there a tool for that?
Hi @Ivan1999,
To check which nodes are sensitive to quantization, you can create a quantization preprocessor (as done here) to iteratively exclude the operations that can cause a significant drop in accuracy when quantized (such as GELU, LayerNorm, and Softmax).
@echarlaix, I tried it, but it didn't work:
from optimum.onnxruntime.preprocessors import QuantizationPreprocessor
from optimum.onnxruntime.preprocessors.passes import (
    ExcludeGeLUNodes,
    ExcludeLayerNormNodes,
    ExcludeNodeAfter,
    ExcludeNodeFollowedBy,
)

# Create a quantization preprocessor to determine the nodes to exclude when applying static quantization
quantization_preprocessor = QuantizationPreprocessor()
# Exclude the nodes constituting LayerNorm
quantization_preprocessor.register_pass(ExcludeLayerNormNodes())
# Exclude the nodes constituting GELU
quantization_preprocessor.register_pass(ExcludeGeLUNodes())
# Exclude the residual connection Add nodes
quantization_preprocessor.register_pass(ExcludeNodeAfter("Add", "Add"))
# Exclude the Add nodes following the Gather operator
quantization_preprocessor.register_pass(ExcludeNodeAfter("Gather", "Add"))
# Exclude the Add nodes followed by the Softmax operator
quantization_preprocessor.register_pass(ExcludeNodeFollowedBy("Add", "Softmax"))
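(For completeness: the registered passes only take effect once the preprocessor is handed to the quantization call. A minimal sketch, assuming the preprocessor argument of ORTQuantizer.quantize as used in the linked Optimum example:)

# Re-run quantization, excluding the nodes collected by the registered passes
model_quantized_path = quantizer.quantize(
    save_dir=onnx_model_path,
    quantization_config=dqconfig,
    preprocessor=quantization_preprocessor,
)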
Can you show the script you used to quantize the model with ONNX Runtime (i.e. without Optimum), please?
It’s been a long time, but after some configuration changes the recall drop of both the ONNX Runtime and the Optimum quantized models is now around 6%.
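For anyone comparing the metrics, a minimal evaluation sketch (assumptions: a labelled evaluation set in texts/labels, the quantized file was saved as model_quantized.onnx inside onnx_model_path, and the positive class is reported by the pipeline as LABEL_1):

from optimum.onnxruntime import ORTModelForSequenceClassification
from sklearn.metrics import precision_score, recall_score
from transformers import AutoTokenizer, pipeline

tokenizer = AutoTokenizer.from_pretrained(model_path)
quantized_model = ORTModelForSequenceClassification.from_pretrained(
    onnx_model_path, file_name="model_quantized.onnx"
)
clf = pipeline("text-classification", model=quantized_model, tokenizer=tokenizer)

# Map pipeline labels such as "LABEL_1" back to integer class ids
preds = [int(out["label"].split("_")[-1]) for out in clf(texts, truncation=True)]
print("precision:", precision_score(labels, preds))
print("recall:", recall_score(labels, preds))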