### Summary
Extends pytorch#142036 by adding an Inductor pattern-matching pattern for the torchao API `int8_dynamic_activation_int8_weight` in the following scenario (inference-only, freezing enabled):
- Activation is symmetrically quantized to int8, per token.
- Weights are symmetrically quantized to int8, per channel, statically, so the weight scales are constants (they would be constants even with dynamic quantization, since the weights themselves are constant). With freezing enabled, the weights are constants as well.
The pattern that's matched is `torch._int_mm` -> convert to FP32/BF16 -> [optional expand for activation scale] -> `mul` -> `mul` (a minimal sketch of this unfused pattern follows below).
We don't check whether the activation is dynamically quantized or whether the weights are statically quantized, since the replacement has no side effects even if that's not the case.
In practice, it also matches the SmoothQuant int8 quantized linear pattern when its output is not reshaped (i.e., when the activation is 2D).
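
A minimal, runnable sketch of the unfused da8w8 pattern described above (shapes, scale values, and variable names are illustrative only, not taken from the PR's test; `torch._int_mm` may impose backend-specific shape constraints):
```python
import torch

# Illustrative shapes only.
M, K, N = 32, 64, 64
x_int8 = torch.randint(-128, 128, (M, K), dtype=torch.int8)   # per-token quantized activation
w_int8 = torch.randint(-128, 128, (N, K), dtype=torch.int8)   # per-channel quantized weight
x_scale = torch.rand(M, 1)                                    # per-token activation scale
w_scale = torch.rand(N)                                       # per-channel weight scale

# torch._int_mm -> convert to FP32 -> mul (activation scale) -> mul (weight scale)
acc_i32 = torch._int_mm(x_int8, w_int8.t())
out = acc_i32.to(torch.float32) * x_scale * w_scale
```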
### More details
oneDNN int8 matmul supports applying a per-channel weight scale, but not a per-token (vector) activation scale; the latter could be applied as a post-op, but that is currently unsupported in ATen. Bias addition (which could be handled with an `add` post-op) is also left unfused.
The fusion pattern used in this PR is `torch._int_mm` -> convert to FP32/BF16 -> `mul`, which is replaced by the oneDNN qlinear op.
The speedup over eager mode comes from two sources:
1. Fusion of the int8 x int8 -> int32 GEMM, the conversion to FP32/BF16, and the application of the weight scale (in the BF16 case, many intermediate conversions are also avoided).
2. The weight is pre-packed and cached by Inductor, so a reorder is avoided at runtime.
In the future, the whole pattern (including the application of the activation scale, which would be a `mul` post-op) plus the bias addition could be fused once the corresponding support is added in ATen.
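
For reference, a rough sketch of how this scenario is typically set up end-to-end, assuming the torchao `quantize_` / `int8_dynamic_activation_int8_weight` API and an illustrative single-linear model; setting the freezing config mirrors `TORCHINDUCTOR_FREEZING=1`:
```python
import torch
import torch._inductor.config as inductor_config
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

# Illustrative model; any model with nn.Linear layers would do.
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()
quantize_(model, int8_dynamic_activation_int8_weight())

# Inference-only compile with freezing enabled, so the int8 weights and their
# scales become constants that Inductor can constant-fold and pre-pack.
inductor_config.freezing = True
compiled = torch.compile(model)
with torch.no_grad():
    out = compiled(torch.randn(32, 64))
```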
### Verification
Added a UT in this PR:
```
python test/inductor/test_mkldnn_pattern_matcher.py -v -k test_da8w8_sym_act_sym_wgt_with_int_mm
```
#### Corresponding torchao UTs
1. int8 SmoothQuant legacy API - `TORCHINDUCTOR_FREEZING=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+inductor" python test/integration/test_integration.py -v -k test_non_dynamically_quantizable_linear`.
The difference from pytorch#139595 is that there are no reshapes of the linear output in this pattern.
2. int8 da8w8 - symmetrically (dynamically) quantized activation & statically quantized weights - `TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+inductor" TORCHINDUCTOR_FREEZING=1 python test/integration/test_integration.py -v -k test_int8_dynamic_quant_subclass_api_0_cpu`
Pull Request resolved: pytorch#142110
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5
ghstack dependencies: pytorch#142036