Fix int8 zero-point overflow and conv bias_scale in eager quantized ref kernels#20655
Conversation
…ef kernels Summary: Fix two numerical bugs in the eager Cadence reference kernels (`ref_implementations.py`) that made them diverge from the deployed C++ kernels. 1. Int8 zero-point overflow: the dequant subtracted the zero-point while the tensor was still int8, so a negative zero-point could overflow and wrap. We now upcast before subtracting. Affects `quantized_add`, `quantized_mul`, `quantized_linear`, `quantized_matmul`, `quantized_conv`, and `quantized_relu`. 2. Conv bias_scale: `quantized_conv_per_tensor` added a pre-scaled bias onto an unscaled integer convolution accumulation, leaving the output off by ~`1/bias_scale`. We now add the integer bias pre-scale and dequantize the whole accumulation by `bias_scale`. Also corrects the uint8 dtype-check error messages in `quantized_add`. Differential Revision: D110220645
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/20655
Note: Links to docs will display an error until the docs builds have been completed. ✅ No FailuresAs of commit 08ef0e5 with merge base d54a0c0 ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
|
@abeakkas has exported this pull request. If you are a Meta employee, you can view the originating Diff in D110220645. |
This PR needs a
|
Summary:
Fix two numerical bugs in the eager Cadence reference kernels (
ref_implementations.py) that made them diverge from the deployed C++ kernels.Int8 zero-point overflow: the dequant subtracted the zero-point while the tensor was still int8, so a negative zero-point could overflow and wrap. We now upcast before subtracting. Affects
quantized_add,quantized_mul,quantized_linear,quantized_matmul,quantized_conv, andquantized_relu.Conv bias_scale:
quantized_conv_per_tensoradded a pre-scaled bias onto an unscaled integer convolution accumulation, leaving the output off by ~1/bias_scale. We now add the integer bias pre-scale and dequantize the whole accumulation bybias_scale.Also corrects the uint8 dtype-check error messages in
quantized_add.Differential Revision: D110220645