
Commit 57c46af

sanchitintel authored and pytorchmergebot committed
[Inductor][CPU] Add torchao da8w8 pattern with sym quantized act & wgt (pytorch#142110)
### Summary

Extends pytorch#142036 with an Inductor pattern-matching pass covering the torchao API `int8_dynamic_activation_int8_weight` in the following scenario (inference-only, freezing enabled):

- Activation is int8 quantized symmetrically, per token.
- Weights are per-channel int8 quantized symmetrically and statically, so their scales are also constant (they would be constant even with dynamic quantization, since the weights are constant); the weights themselves are constant because freezing is enabled.

The pattern that is matched is `torch._int_mm` -> convert to FP32/BF16 -> [optional expand of the activation scale] -> `mul` -> `mul`. We don't check whether the activation is actually dynamically quantized or whether the weights are statically quantized, since the replacement has no side effects even if that is not the case. In practice, the pattern also matches the SmoothQuant int8 quantized linear pattern when its output is not reshaped (i.e. when the activation is 2D).

### More details

The oneDNN int8 matmul supports applying a per-channel weight scale but not a vector activation scale, which could be applied as a post-op but is currently unsupported in ATen. Bias addition (which could be supported with an add post-op) is also unfused. The fusion pattern used in this PR is `torch._int_mm` -> convert to FP32/BF16 -> `mul`, which is replaced by the oneDNN qlinear op. The speedup over eager mode comes from two sources:

1. Fusion of the int8 x int8 -> int32 GEMM, the conversion to FP32/BF16, and the application of the weight scale (in the BF16 case, many intermediate conversions are also avoided).
2. The weight is pre-packed and cached by Inductor, so a reorder is avoided at run time.

In the future, the whole pattern (including application of the activation scale, which would be a mul post-op) plus the bias could be fused, once the corresponding support is added in ATen.

### Verification

Added a UT in this PR:

```
python test/inductor/test_mkldnn_pattern_matcher.py -v -k test_da8w8_sym_act_sym_wgt_with_int_mm
```

#### Corresponding torchao UTs

1. int8 SmoothQuant legacy API: `TORCHINDUCTOR_FREEZING=1 TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+inductor" python test/integration/test_integration.py -v -k test_non_dynamically_quantizable_linear`. The difference from pytorch#139595 is that there are no reshapes of the linear output in this pattern.
2. int8 da8w8 (symmetrically quantized activation, dynamically, & statically quantized weights): `TORCH_COMPILE_DEBUG=1 TORCH_LOGS="+inductor" TORCHINDUCTOR_FREEZING=1 python test/integration/test_integration.py -v -k test_int8_dynamic_quant_subclass_api_0_cpu`

Pull Request resolved: pytorch#142110
Approved by: https://github.com/leslie-fang-intel, https://github.com/jgong5
ghstack dependencies: pytorch#142036
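For orientation, the eager-mode computation that this pattern targets looks roughly like the following minimal sketch (distilled from the unit test added in this PR; `da8w8_linear_reference` and its argument names are illustrative, not part of the change):

```python
import torch

def da8w8_linear_reference(a_int8, b_int8, a_scale, b_scale, out_dtype=torch.bfloat16):
    # int8 x int8 -> int32 GEMM
    c = torch._int_mm(a_int8, b_int8)
    # Convert to FP32/BF16, then apply the per-token activation scale
    # (possibly via an expand node in the traced graph) and the per-channel weight scale.
    c = c.to(out_dtype)
    c = c * a_scale.expand(c.shape)  # a_scale: shape [M, 1]
    c = c * b_scale                  # b_scale: shape [N]
    return c
```

With freezing enabled, the `_int_mm` -> convert -> weight-scale `mul` portion of this graph is what gets replaced by the oneDNN qlinear op; the activation-scale `mul` (and any bias add) remains unfused for now.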
1 parent b731ced commit 57c46af

File tree

2 files changed (+144 -43 lines)


test/inductor/test_mkldnn_pattern_matcher.py

Lines changed: 72 additions & 0 deletions
@@ -3375,6 +3375,78 @@ def matcher_check_fn():
             compile_options={"dynamic": dynamic},
         )
 
+    @skipIfNoDynamoSupport
+    @skipIfNoONEDNN
+    # TODO: investigate options of torch.compile in fbcode
+    @unittest.skipIf(IS_FBCODE, "Failing in fbcode")
+    @parametrize("has_bias", [True, False])
+    @parametrize("dtype", [torch.float, torch.bfloat16])
+    @parametrize("dynamic", [True, False])
+    @parametrize("reshape_a", [True, False])
+    def test_da8w8_sym_act_sym_wgt_with_int_mm(
+        self, has_bias, dtype, dynamic, reshape_a
+    ):
+        r"""
+        This testcase checks if we can match the int8_dynamic_activation_int8_weight int8 linear pattern from torchao,
+        when the activation is symmetrically quantized dynamically & the weights are symmetrically quantized (statically).
+        The pattern is:
+        (no bias) _int_mm -> convert_element_type -> ([expand_a] -> mul) -> mul
+        or
+        (with bias) pattern_no_bias -> add
+        Expansion of the activation scale is optional.
+        The pattern depiction doesn't mean that the convert_element_type output is fed into expand_a as input,
+        but simply that the activation scale may be applied after an expand operation on it.
+        """
+        if dtype == torch.bfloat16 and not torch.ops.mkldnn._is_mkldnn_bf16_supported():
+            return
+        M = 32
+        in_feature = 32
+        out_feature = 64
+        q_min, q_max = -32, 31
+
+        class Mod(torch.nn.Module):
+            def __init__(self, dtype: torch.dtype, has_bias: bool):
+                super().__init__()
+                self.dtype = dtype
+                self.has_bias = has_bias
+                self.b = torch.randint(
+                    q_min, q_max, [in_feature, out_feature], dtype=torch.int8
+                )
+                self.a_scale = torch.rand([M, 1], dtype=dtype) * 0.01 + 0.01
+                self.b_scale = torch.rand([out_feature]) * 0.01 + 0.01
+                self.b_scale = self.b_scale.to(dtype)
+                self.bias = torch.rand([out_feature], dtype=dtype) if has_bias else None
+
+            def forward(self, a):
+                if reshape_a:
+                    a_reshaped = a.reshape(-1, a.size(-1))
+                else:
+                    a_reshaped = a
+                c = torch._int_mm(a_reshaped, self.b)
+                c = c.to(self.dtype)
+                a_scale = self.a_scale.expand(c.shape)
+                c = c * a_scale
+                c = c * self.b_scale
+                if self.has_bias:
+                    c = c + self.bias
+                return c
+
+        mod = Mod(dtype, has_bias).eval()
+        a = torch.randint(q_min, q_max, [M, in_feature], dtype=torch.int8)
+
+        def matcher_check_fn():
+            self.assertEqual(
+                counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1
+            )
+
+        self._test_common(
+            mod,
+            (a,),
+            matcher_check_fn=matcher_check_fn,
+            check_autocast=dtype,
+            compile_options={"dynamic": dynamic},
+        )
+
 
 @dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
 class TestDynamicPatternMatcher(TestPatternMatcherBase):

torch/_inductor/fx_passes/quantization.py

Lines changed: 72 additions & 43 deletions
@@ -2812,50 +2812,53 @@ def _register_smooth_quant_int_mm_pattern():
 
     # When torch.compile'ing with dynamic=True, the expand node and the two tailing reshape nodes exist
     # When torch.compile'ing with dynamic=False, they don't exist
-    def get_pattern_no_bias(expand_a_scale: bool):
+    def get_pattern_no_bias(expand_a_scale: bool, reshape_a: bool = True):
         return CallFunction(
-            aten.reshape.default,
+            aten.mul.Tensor,
             CallFunction(
                 aten.mul.Tensor,
                 CallFunction(
-                    aten.mul.Tensor,
+                    prims.convert_element_type.default,
                     CallFunction(
-                        prims.convert_element_type.default,
-                        CallFunction(
-                            aten._int_mm.default,
-                            CallFunction(
-                                aten.reshape.default,
-                                KeywordArg("a"),
-                                KeywordArg("in_shape"),
-                            ),
-                            KeywordArg("b"),
-                        ),
-                        KeywordArg("dtype"),
-                    ),
-                    (
+                        aten._int_mm.default,
                         CallFunction(
-                            aten.expand.default,
-                            KeywordArg("x_scale"),
-                            Arg(),
+                            aten.reshape.default,
+                            KeywordArg("a"),
+                            KeywordArg("in_shape"),
                         )
-                        if expand_a_scale
-                        else KeywordArg("x_scale")
+                        if reshape_a
+                        else KeywordArg("a"),
+                        KeywordArg("b"),
                     ),
+                    KeywordArg("dtype"),
+                ),
+                (
+                    CallFunction(
+                        aten.expand.default,
+                        KeywordArg("x_scale"),
+                        Arg(),
+                    )
+                    if expand_a_scale
+                    else KeywordArg("x_scale")
                 ),
-                KeywordArg("w_scale"),
             ),
-            KeywordArg("out_shape_no_bias"),
+            KeywordArg("w_scale"),
+        )
+
+    def _with_outer_reshape(pattern):
+        return CallFunction(
+            aten.reshape.default, pattern, KeywordArg("out_shape_no_bias")
         )
 
     # for torch.compile(dynamic=False)
-    pattern_no_bias_1 = get_pattern_no_bias(expand_a_scale=False)
+    pattern_no_bias_1 = _with_outer_reshape(get_pattern_no_bias(expand_a_scale=False))
     pattern_with_bias_1 = CallFunction(
         aten.add.Tensor,
         pattern_no_bias_1,
         KeywordArg("bias"),
     )
     # for torch.compile(dynamic=True)
-    pattern_no_bias_2 = get_pattern_no_bias(expand_a_scale=True)
+    pattern_no_bias_2 = _with_outer_reshape(get_pattern_no_bias(expand_a_scale=True))
     pattern_with_bias_2 = CallFunction(
         aten.reshape.default,
         CallFunction(
@@ -2870,15 +2873,26 @@ def get_pattern_no_bias(expand_a_scale: bool):
         KeywordArg("out_shape_with_bias"),
     )
 
+    # The following patterns are for torchao int8_dynamic_activation_int8_weight linear,
+    # when both activation and weights are symmetrically quantized.
+    # In practice, though, they may also match smooth-quant pattern when a 2D input shape would be used.
+    # Since add is not currently being used as a oneDNN post-op, but is unfused, we don't need these patterns with bias.
+    # Ideally, we should add mul + add post-op support in ATen int8 oneDNN linear op.
+    pattern1_with_no_outer_or_act_reshape = get_pattern_no_bias(
+        expand_a_scale=False, reshape_a=False
+    )
+    pattern2_with_no_outer_or_act_reshape = get_pattern_no_bias(
+        expand_a_scale=True, reshape_a=False
+    )
+
     def _validate_pattern(match: Match):
-        if len(match.nodes) not in [6, 7, 10]:
+        if len(match.nodes) not in [4, 5, 6, 7, 10]:
             return False
         # Make sure weight is a constant
-        if match.nodes[1].target != aten._int_mm.default:
+        aten_int_mm_node = filter_nodes(match.nodes, aten._int_mm.default)[0]
+        if not isinstance(aten_int_mm_node.args[1], torch.fx.node.Node):
             return False
-        if not isinstance(match.nodes[1].args[1], torch.fx.node.Node):
-            return False
-        if match.nodes[1].args[1].op != "get_attr":
+        if aten_int_mm_node.args[1].op != "get_attr":
             return False
 
         if len(match.nodes) == 10:
@@ -2902,6 +2916,8 @@ def _validate_pattern(match: Match):
         pattern_with_bias_2: 0,
         pattern_no_bias_1: 1,
         pattern_with_bias_1: 1,
+        pattern1_with_no_outer_or_act_reshape: 2,
+        pattern2_with_no_outer_or_act_reshape: 2,
     }
     for pattern, pass_number in pattern_to_pass_number.items():
 
@@ -2978,9 +2994,13 @@ def _int_mm_weight_prepack(match: Match, *args, **kwargs):
             else:
                 # onednn.qlinear does not support per-channel quantization of x
                 # so in this case, we have to apply x scale and add bias ourselves after qlinear
-                x_reshaped = match.graph.call_function(
-                    aten.reshape.default, args=(x, kwargs["in_shape"])
-                )
+                in_shape = kwargs.get("in_shape", None)
+                if in_shape is None:
+                    x_reshaped = x
+                else:
+                    x_reshaped = match.graph.call_function(
+                        aten.reshape.default, args=(x, in_shape)
+                    )
                 new_args = (
                     x_reshaped,
                     1.0,  # x_scale
@@ -3003,23 +3023,32 @@ def _int_mm_weight_prepack(match: Match, *args, **kwargs):
                 new_out_node = match.graph.call_function(
                     aten.mul.Tensor, args=(new_linear_node, x_scale)
                 )
+
                 # Add bias and reshape
-                out_shape = kwargs.get(
-                    "out_shape_with_bias", kwargs["out_shape_no_bias"]
+                has_outer_reshape = (
+                    kwargs.get("out_shape_with_bias", None) is not None
+                    or kwargs.get("out_shape_no_bias", None) is not None
                 )
+
+                if has_outer_reshape:
+                    out_shape = kwargs.get(
+                        "out_shape_with_bias", kwargs["out_shape_no_bias"]
+                    )
                 if bias is not None:
                     new_out_node = match.graph.call_function(
                         aten.add.Tensor, args=(new_out_node, bias)
                     )
-                    new_out_node = match.graph.call_function(
-                        aten.reshape.default,
-                        args=(new_out_node, out_shape),
-                    )
+                    if has_outer_reshape:
+                        new_out_node = match.graph.call_function(
+                            aten.reshape.default,
+                            args=(new_out_node, out_shape),  # type: ignore[possibly-undefined]
+                        )
                 else:
-                    new_out_node = match.graph.call_function(
-                        aten.reshape.default,
-                        args=(new_out_node, out_shape),
-                    )
+                    if has_outer_reshape:
+                        new_out_node = match.graph.call_function(
+                            aten.reshape.default,
+                            args=(new_out_node, out_shape),  # type: ignore[possibly-undefined]
+                        )
             out_node.replace_all_uses_with(new_out_node)
             new_out_node.meta.update(out_node.meta)
             for node in reversed(match.nodes):
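
For context, here is a hedged end-to-end sketch of how this fusion path is typically exercised from torchao (assuming the `torchao.quantization.quantize_` and `int8_dynamic_activation_int8_weight` APIs; the module and shapes are illustrative, and freezing must be enabled, e.g. `TORCHINDUCTOR_FREEZING=1`, so that the constant-weight check in `_validate_pattern` passes):

```python
import torch
from torchao.quantization import int8_dynamic_activation_int8_weight, quantize_

class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 64, bias=False)

    def forward(self, x):
        return self.linear(x)

with torch.no_grad():
    m = TinyLinear().eval()
    # Symmetric dynamic int8 quantization of activations, static per-channel int8 weights.
    quantize_(m, int8_dynamic_activation_int8_weight())
    # Run under TORCHINDUCTOR_FREEZING=1 so the weights become graph constants that
    # the pattern matcher can prepack for the oneDNN qlinear replacement.
    compiled = torch.compile(m)
    out = compiled(torch.randn(8, 32))
```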

0 commit comments
