
Commit 5ded23c

Adds utility to replace Q/DQ ops with torchao quantized linear ops (#1967)
* up * up * up * up
1 parent dd7db9f commit 5ded23c
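
The new pass operates on an exported program whose linear layers were quantized with QDQLayout. Below is a minimal usage sketch of the intended flow, mirroring the test added in this commit; the module sizes and config values are illustrative only.

import torch

from torchao.experimental.q_dq_layout import QDQLayout
from torchao.experimental.quant_api import Int8DynamicActivationIntxWeightConfig
from torchao.experimental.quant_passes import (
    replace_q_dq_patterns_with_quantized_linear_ops_pass,
)
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quant_api import quantize_

# Quantize so that export produces Q/DQ (QDQLayout) nodes around each linear
model = torch.nn.Sequential(torch.nn.Linear(64, 64))
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        granularity=PerGroup(32),
        has_weight_zeros=True,
        layout=QDQLayout(),
    ),
)

# Export, then rewrite the Q/DQ patterns into torchao quantized linear ops
activations = torch.randn(2, 1, 64, dtype=torch.float32)
ep = torch.export.export(model, (activations,), strict=True)
ep = replace_q_dq_patterns_with_quantized_linear_ops_pass(ep)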

File tree

3 files changed: +300 −0 lines changed


.github/workflows/torchao_experimental_test.yml

+1

@@ -44,6 +44,7 @@ jobs:
           conda activate venv
           pytest torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py
           python torchao/experimental/tests/test_embedding_xbit_quantizer.py
+          python torchao/experimental/tests/test_quant_passes.py
       - name: Run kernels/cpu/aarch64/tests
         run: |
           conda activate venv

torchao/experimental/quant_passes.py

+217
@@ -0,0 +1,217 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import itertools
from collections import defaultdict
from typing import Callable, Optional

import torch
from torch._export.passes.constant_folding import (
    ConstantFolder,
    replace_node_with_constant,
)
from torch.fx import subgraph_rewriter


def constant_fold(
    gm: torch.fx.GraphModule,
    constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
    skip_constructors: bool = False,
):
    with torch.utils._python_dispatch._disable_current_modes():
        # The ConstantFolder has a bug where it throws if dequantize_affine is not defined
        # TODO: fix upstream
        try:
            getattr(torch.ops.pt2e_quant, "dequantize_affine")
        except AttributeError:
            setattr(torch.ops.pt2e_quant, "dequantize_affine", None)

        cf = ConstantFolder(gm, skip_constructors)
        cf.run()

        for node, constant in cf.node_replacements.items():
            if constraint_fn is not None and not constraint_fn(node):
                continue
            replace_node_with_constant(gm, node, constant)

        erased_params = []
        # Get all attr users by looking up the graph instead of node.users, because in this case
        # _tensor_constant0 and _tensor_constant0_1 are actually referring to the same tensor.

        # opcode         name                 target            args                         kwargs
        # -------------  -------------------  ----------------  ---------------------------  --------
        # placeholder    arg0_1               arg0              ()                           {}
        # get_attr       _tensor_constant0    state             ()                           {}
        # call_function  add                  aten.add.Tensor   (arg0_1, _tensor_constant0)  {}
        # get_attr       _tensor_constant0_1  state             ()                           {}
        # call_function  add_                 aten.add_.Tensor  (_tensor_constant0_1, 1)     {}
        # output         output               output            ([add],)                     {}

        get_attr_node_users = defaultdict(list)
        for node in gm.graph.nodes:
            if node.op == "get_attr":
                get_attr_node_users[node.target].extend(node.users.keys())
        for node in gm.graph.find_nodes(op="get_attr"):
            if node.op == "get_attr" and len(get_attr_node_users[node.target]) == 0:
                if hasattr(gm, node.target):
                    delattr(gm, node.target)
                erased_params.append(node)
        for node in erased_params:
            gm.graph.erase_node(node)

        gm.graph.eliminate_dead_code()
        gm.graph.lint()
        gm.recompile()


def _get_q_dq_linear_patterns_replacements_and_filters(
    weight_bit_width, has_weight_zeros, target
):
    glbs = globals()
    glbs["weight_bit_width"] = weight_bit_width
    glbs["target"] = target
    glbs["w_quant_min"] = -(1 << (weight_bit_width - 1))
    glbs["w_quant_max"] = (1 << (weight_bit_width - 1)) - 1
    glbs["a_quant_min"] = -128
    glbs["a_quant_max"] = 127
    glbs["a_mapping_type"] = "ASYMMETRIC"
    glbs["a_scale_dtype"] = torch.float32
    glbs["a_eps"] = None

    lcls = {}

    pattern_str = f"""
def pattern(
        a, a_block_size, a_target_dtype, a_zero_point_dtype,
        w_int_data, w_block_size, w_scale, w_zero_point, w_target_dtype,
        bias):
    a_scale, a_zero_point = torch.ops.quant.choose_qparams_affine.default(
        a,
        a_mapping_type,
        a_block_size,
        a_target_dtype,
        a_quant_min,
        a_quant_max,
        a_eps,
        a_scale_dtype,
        a_zero_point_dtype,
    )
    a_int_data = torch.ops.quant.quantize_affine.default(
        a, a_block_size, a_scale, a_zero_point, a_target_dtype, a_quant_min, a_quant_max,
    )
    dq_a = torch.ops.quant.dequantize_affine.default(
        a_int_data, a_block_size, a_scale, a_zero_point, a_target_dtype, a_quant_min, a_quant_max
    )
    dq_w = torch.ops.quant.dequantize_affine.default(
        w_int_data,
        w_block_size,
        w_scale,
        w_zero_point,
        w_target_dtype,
        w_quant_min,
        w_quant_max,
        {"'INT'" if has_weight_zeros else "'NONE'"}
    )
    return torch.ops.aten.linear.default(dq_a, dq_w, bias)
"""
    exec(pattern_str, glbs, lcls)
    pattern = lcls["pattern"]

    replacement_str = f"""
def replacement(
        a, a_block_size, a_target_dtype, a_zero_point_dtype,
        w_int_data, w_block_size, w_scale, w_zero_point, w_target_dtype,
        bias,):
    n = w_int_data.size(0)
    k = a_block_size[-1]
    group_size = w_block_size[-1]
    out_shape = a.shape[:-1] + (n,)
    packed_weight = getattr(
        torch.ops.torchao,
        f"_pack_8bit_act_{weight_bit_width}bit_weight",
    )(
        w_int_data.to(torch.int8),
        w_scale.reshape(-1),
        {"w_zero_point.reshape(-1).to(torch.int8)" if has_weight_zeros else "None"},
        group_size,
        bias,
        target,
    )
    return getattr(
        torch.ops.torchao, f"_linear_8bit_act_{weight_bit_width}bit_weight"
    )(a.reshape(-1, k), packed_weight, group_size, n, k).reshape(out_shape)
"""

    exec(replacement_str, glbs, lcls)
    replacement = lcls["replacement"]

    def match_filter(match, x, y):
        def get_val(name):
            node = [n for n in match.nodes_map if n.name == name][0]
            return match.nodes_map[node]

        int_types = [torch.int8, torch.int16, torch.int32, torch.int64]

        a_target_dtype = get_val("a_target_dtype")
        if a_target_dtype not in int_types:
            return False

        a_zero_point_dtype = get_val("a_zero_point_dtype")
        if a_zero_point_dtype not in int_types:
            return False

        # We only want a_block_size with shape [1, ..., 1, k]
        a_block_size = get_val("a_block_size")
        for d in a_block_size[0:-1]:
            if d != 1:
                print("a_block_size not [1, ..., 1, k]")
                return False

        # We only want w_block_size with shape [1, group_size]
        w_block_size = get_val("w_block_size")
        if len(w_block_size) != 2 or w_block_size[0] != 1:
            return False

        return True

    return pattern, replacement, match_filter


def replace_q_dq_patterns_with_quantized_linear_ops_pass(
    ep: torch.export.ExportedProgram,
    target=None,
) -> torch.export.ExportedProgram:
    """
    This replaces Q/DQ patterns with torchao quantized linear ops.
    It is intended for converting Q/DQ nodes exported with QDQLayout to
    the lowbit quantized linear ops.
    """
    # TODO: figure out how to do this with dynamic_shapes (not saved on EP for easy re-export)
    # See https://fb.workplace.com/groups/1028545332188949/permalink/1185289956514485/
    assert (
        len(ep.range_constraints) == 0
    ), "ExportedPrograms with range constraints are not supported"

    # ep.module() unlifts the weight inputs, which we need for constant folding
    gm = ep.module()
    for weight_bit_width, has_weight_zeros in itertools.product(
        range(1, 9), [True, False]
    ):
        pattern, replacement, match_filter = (
            _get_q_dq_linear_patterns_replacements_and_filters(
                weight_bit_width, has_weight_zeros, target
            )
        )
        subgraph_rewriter.replace_pattern_with_filters(
            gm, pattern, replacement, match_filters=[match_filter]
        )

    # Constant fold evaluates and removes the packing ops
    constant_fold(gm)

    # Re-export
    return torch.export.export(gm, *ep.example_inputs)
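
For orientation, here is roughly what the exec'd replacement expands to for weight_bit_width=4 with weight zeros. This is a sketch derived by hand from replacement_str above (the getattr indirection is flattened for readability), not source generated by the commit.

def replacement(
    a, a_block_size, a_target_dtype, a_zero_point_dtype,
    w_int_data, w_block_size, w_scale, w_zero_point, w_target_dtype,
    bias,
):
    n = w_int_data.size(0)
    k = a_block_size[-1]
    group_size = w_block_size[-1]
    out_shape = a.shape[:-1] + (n,)
    # Pack once; constant folding later evaluates and removes this op
    packed_weight = torch.ops.torchao._pack_8bit_act_4bit_weight(
        w_int_data.to(torch.int8),
        w_scale.reshape(-1),
        w_zero_point.reshape(-1).to(torch.int8),  # becomes None when has_weight_zeros=False
        group_size,
        bias,
        target,  # resolved from the globals installed in glbs
    )
    return torch.ops.torchao._linear_8bit_act_4bit_weight(
        a.reshape(-1, k), packed_weight, group_size, n, k
    ).reshape(out_shape)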
torchao/experimental/tests/test_quant_passes.py

+82
@@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from torch.testing import FileCheck

from torchao.experimental.q_dq_layout import QDQLayout
from torchao.experimental.quant_api import (
    Int8DynamicActivationIntxWeightConfig,
)
from torchao.experimental.quant_passes import (
    replace_q_dq_patterns_with_quantized_linear_ops_pass,
)
from torchao.quantization.granularity import PerGroup, PerRow
from torchao.quantization.quant_api import quantize_


class TestQuantPasses(unittest.TestCase):
    def test_replace_q_dq_patterns_with_quantized_linear_ops_pass(self):
        layers = []
        layer_to_weight_dtype = {}
        layer_to_has_weight_zeros = {}
        for weight_dtype in [getattr(torch, f"int{i}") for i in range(1, 9)]:
            for has_weight_zeros in [True, False]:
                for has_bias in [True, False]:
                    idx = len(layers)
                    layer_to_weight_dtype[idx] = weight_dtype
                    layer_to_has_weight_zeros[idx] = has_weight_zeros
                    layers.append(torch.nn.Linear(64, 64, bias=has_bias))
        activations = torch.randn(2, 1, 64, dtype=torch.float32)

        model = torch.nn.Sequential(*layers)
        for idx in range(len(layers)):
            quantize_(
                model,
                Int8DynamicActivationIntxWeightConfig(
                    weight_dtype=layer_to_weight_dtype[idx],
                    # Test out different granularities
                    granularity=PerGroup(32) if idx % 2 == 0 else PerRow(),
                    has_weight_zeros=layer_to_has_weight_zeros[idx],
                    layout=QDQLayout(),
                ),
                lambda m, fqn: fqn == str(idx),
            )

        eager_results = model(activations)
        exported = torch.export.export(model, (activations,), strict=True)
        exported = replace_q_dq_patterns_with_quantized_linear_ops_pass(exported)

        # We should not find pack op because it gets constant folded
        FileCheck().check_not("torch.ops.torchao._pack_8bit_act").run(
            exported.graph_module.code
        )

        # We should find len(layers) torchao linear ops
        FileCheck().check_count(
            "torch.ops.torchao._linear_8bit_act_", count=len(layers), exactly=True
        ).run(exported.graph_module.code)

        # We should not find Q/DQ ops
        FileCheck().check_not("torch.ops.quant.quantize_affine.default").run(
            exported.graph_module.code
        )
        FileCheck().check_not("torch.ops.quant.dequantize_affine.default").run(
            exported.graph_module.code
        )
        FileCheck().check_not("torch.ops.quant.choose_qparams_affine.default").run(
            exported.graph_module.code
        )

        # Numerics should match
        exported_results = exported.module()(activations)
        self.assertTrue(torch.allclose(exported_results, eager_results))


if __name__ == "__main__":
    unittest.main()
