Commit 54241a9

jerryzh168 authored and facebook-github-bot committed
[quant][fx] Add support for fused modules in _convert_do_not_use (pytorch#67245)
Summary:
Pull Request resolved: pytorch#67245

Add support for fused modules in the new convert path, including linear-relu, conv{1-3}d-relu, and their QAT versions; also tested with TRT (conv2d-relu and linear-relu).

Test Plan:
```
python test/fx2trt/test_quantize_fx.py TestQuantizeFxTRTOps.test_linear_relu_module
python test/fx2trt/test_quantize_fx.py TestQuantizeFxTRTOps.test_conv_relu_module
```

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D31919724

fbshipit-source-id: 7e5c96eba30706f7989da680aa3443159847bdfd
1 parent 91971df commit 54241a9
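
For context, the "fused modules" this commit teaches the new convert path to handle are the `torch.nn.intrinsic` wrapper types produced by fusion. A minimal sketch (not part of the commit; the toy module `M` is invented for illustration) of how such a module arises via eager-mode fusion:

```python
import torch
import torch.nn.intrinsic as nni
from torch.ao.quantization import fuse_modules

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(5, 10)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.linear(x))

# fuse_modules folds linear + relu into a single nni.LinearReLU, one of
# the fused module types the new convert path (and the TensorRT backend
# config below) now recognizes.
m = fuse_modules(M().eval(), [["linear", "relu"]])
assert isinstance(m.linear, nni.LinearReLU)
```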

File tree: 3 files changed, +193 −45 lines

test/fx2trt/test_quant_trt.py

Lines changed: 62 additions & 32 deletions
```diff
@@ -20,10 +20,13 @@
 from torch.testing._internal.common_quantization import (
     QuantizationTestCase,
 )
+import torch.nn.functional as F
+
 from torch.testing._internal.common_cuda import TEST_CUDA
 from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.common_quantization import NodeSpec as ns
 import unittest
+import itertools
 
 def lower_to_trt(model, inputs, shape_ranges):
     """ Lower a quantized model to TensorRT
@@ -92,36 +95,60 @@ def _test_module(
         trt_mod(*inputs_cuda)
 
 
-    def test_conv(self):
-        class Conv2dModule(torch.nn.Module):
-            def __init__(self):
+    def test_conv_relu_module(self):
+        conv_module = {1 : torch.nn.Conv1d, 2 : torch.nn.Conv2d, 3 : torch.nn.Conv3d}
+
+        conv1d_input = torch.rand(1, 3, 10)
+        conv2d_input = torch.rand(1, 3, 10, 10)
+        conv3d_input = torch.rand(1, 3, 10, 10, 10)
+        conv_input = {1: conv1d_input, 2: conv2d_input, 3: conv3d_input}
+
+        class ConvNdModule(torch.nn.Module):
+            def __init__(self, dim, has_relu=False, f_relu=False):
                 super().__init__()
-                self.conv = torch.nn.Conv2d(3, 3, 3)
+                self.conv = conv_module[dim](3, 3, 3).float()
+                if has_relu:
+                    if f_relu:
+                        self.relu = F.relu
+                    else:
+                        self.relu = torch.nn.ReLU()
+                else:
+                    self.relu = torch.nn.Identity()
 
             def forward(self, x):
-                return self.conv(x)
-
-        conv2d_input = torch.rand(1, 3, 224, 224)
-        no_convert = {
-            ns.call_function(torch.quantize_per_tensor): 2,
-            ns.call_method("dequantize"): 2
-        }
-        self._test_module(
-            Conv2dModule(),
-            [conv2d_input],
-            [((1, 3, 224, 224),
-              (5, 3, 224, 224),
-              (10, 3, 224, 224))],
-            no_convert=no_convert)
-
-    def test_linear(self):
+                return self.relu(self.conv(x))
+
+        # just testing conv2d since conv1d and conv3d are not supported in fx2trt
+        for dim, has_relu, f_relu in itertools.product([2], [True, False], [True, False]):
+            # when has_relu=False, we have torch.nn.Identity, which would introduce
+            # extra quant-dequant pair
+            no_convert = {
+                ns.call_function(torch.quantize_per_tensor): 2 + int(not has_relu),
+                ns.call_method("dequantize"): 2 + int(not has_relu),
+            }
+            self._test_module(
+                ConvNdModule(dim, has_relu, f_relu),
+                [conv_input[dim]],
+                [((1, *conv_input[dim].shape[1:]),
+                  (5, *conv_input[dim].shape[1:]),
+                  (10, *conv_input[dim].shape[1:]))],
+                no_convert=no_convert)
+
+    def test_linear_relu_module(self):
         class LinearModule(torch.nn.Module):
-            def __init__(self):
+            def __init__(self, has_relu=False, f_relu=False):
                 super().__init__()
-                self.linear = torch.nn.Linear(5, 10)
+                self.linear = torch.nn.Linear(5, 10).float()
+                if has_relu:
+                    if f_relu:
+                        self.relu = F.relu
+                    else:
+                        self.relu = torch.nn.ReLU()
+                else:
+                    self.relu = torch.nn.Identity()
 
             def forward(self, x):
-                return self.linear(x)
+                return self.relu(self.linear(x))
 
         linear_input = torch.rand(8, 5)
 
@@ -130,15 +157,18 @@ def forward(self, x):
              (5, 5),
             (10, 5))
         ]
-        no_convert = {
-            ns.call_function(torch.quantize_per_tensor): 2,
-            ns.call_method("dequantize"): 2,
-        }
-        self._test_module(
-            LinearModule(),
-            [linear_input],
-            shape_ranges,
-            no_convert=no_convert)
+        for has_relu, f_relu in itertools.product([True, False], [True, False]):
+            # when has_relu=False, we have torch.nn.Identity, which would introduce
+            # extra quant-dequant pair
+            no_convert = {
+                ns.call_function(torch.quantize_per_tensor): 2 + int(not has_relu),
+                ns.call_method("dequantize"): 2 + int(not has_relu),
+            }
+            self._test_module(
+                LinearModule(has_relu, f_relu),
+                [linear_input],
+                shape_ranges,
+                no_convert=no_convert)
 
     def test_ops(self):
         class M(torch.nn.Module):
```
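
The `2 + int(not has_relu)` counts above are the point of these tests: when `has_relu=False` the model ends in `torch.nn.Identity`, which only shares its input observer, so one extra quantize/dequantize pair survives convert. A standalone sketch of the expected counts (illustration only, not part of the test file):

```python
import itertools

# Same variant grid as the loops above: two quantize_per_tensor/dequantize
# pairs around the weighted op, plus one more pair when the trailing
# torch.nn.Identity is present (has_relu=False).
for has_relu, f_relu in itertools.product([True, False], [True, False]):
    expected = 2 + int(not has_relu)
    print(f"has_relu={has_relu}, f_relu={f_relu}: "
          f"{expected} quantize_per_tensor / {expected} dequantize nodes")
```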

torch/ao/quantization/fx/_convert_do_not_use.py

Lines changed: 71 additions & 13 deletions
```diff
@@ -35,6 +35,34 @@
 
 from .convert import restore_state
 
+# these are tuples so that they can work with isinstance(module, tuple_of_classes)
+WEIGHTED_MODULE_CLASSES = (
+    torch.nn.Linear,
+    torch.nn.Conv1d,
+    torch.nn.Conv2d,
+    torch.nn.Conv3d
+)
+
+FUSED_MODULE_CLASSES = (
+    torch.nn.intrinsic.LinearReLU,
+    torch.nn.intrinsic.ConvReLU1d,
+    torch.nn.intrinsic.ConvReLU2d,
+    torch.nn.intrinsic.ConvReLU3d,
+)
+
+QAT_MODULE_CLASSES = (
+    torch.nn.qat.Linear,
+    torch.nn.qat.Conv2d,
+    torch.nn.qat.Conv3d,
+    torch.nn.intrinsic.qat.LinearReLU,
+    torch.nn.intrinsic.qat.ConvBn2d,
+    torch.nn.intrinsic.qat.ConvBnReLU2d,
+    torch.nn.intrinsic.qat.ConvReLU2d,
+    torch.nn.intrinsic.qat.ConvBn3d,
+    torch.nn.intrinsic.qat.ConvBnReLU3d,
+    torch.nn.intrinsic.qat.ConvReLU3d
+)
+
 def _convert_do_not_use(
     model: GraphModule, is_reference: bool = False,
     convert_custom_config_dict: Dict[str, Any] = None,
@@ -64,7 +92,7 @@ def _convert_do_not_use(
     patterns, node_name_to_scope, prepare_custom_config_dict, observed_node_names = restore_state(model)
     qconfig_map: Dict[str, QConfigAny] = model._qconfig_map  # type: ignore[assignment]
 
-    assert is_reference, "convert2 only supports reference option"
+    assert is_reference, "_convert_do_not_use only supports reference option"
 
     # mapping from fully qualified module name to module instance
     # for example,
@@ -167,24 +195,54 @@ def replace_observer_with_quantize_dequantize_node(graph: Graph, node: Node, mod
         elif node.op == "call_module":
             if is_activation_post_process(modules[node.target]):
                 replace_observer_with_quantize_dequantize_node(model.graph, node, modules)
-            elif type(modules[node.target]) in [
-                    torch.nn.Linear,
-                    torch.nn.Conv1d,
-                    torch.nn.Conv2d,
-                    torch.nn.Conv3d]:
-                fmodule = modules[node.target]
-                qconfig = fmodule.qconfig
+            elif type(modules[node.target]) in set(
+                    WEIGHTED_MODULE_CLASSES).union(QAT_MODULE_CLASSES).union(FUSED_MODULE_CLASSES):
+                # TODO: refactor this part to a function
+                original_module = modules[node.target]
+                qconfig = original_module.qconfig
 
                 is_observed = node.name in observed_node_names
                 is_weight_quantized = weight_is_statically_quantized(qconfig)
                 # TODO: rename weight_is_statically_quantized to weight_is_int8_quantized
-                if qconfig is not None and is_observed and is_weight_quantized:
+                if qconfig is None or not is_observed or not is_weight_quantized:
+                    continue
+
+                float_module = original_module
+                fused_module = None
+                if isinstance(
+                        original_module,
+                        QAT_MODULE_CLASSES):
+                    # case 1. converting qat module to
+                    # a float module, we need to attach
+                    # weight fake_quant to the module,
+                    # weight fake_quant is assumed to be run during
+                    # QAT so we don't need to run it again here
+                    float_module = original_module.to_float()  # type: ignore[operator]
+                    # change qat conv to conv
+                    parent_name, name = _parent_name(node.target)
+                    setattr(modules[parent_name], name, float_module)
+                    if isinstance(float_module, torch.nn.intrinsic._FusedModule):
+                        fused_module = float_module
+                        float_module = fused_module[0]
+                    weight_post_process = original_module.weight_fake_quant
+                else:
+                    # case 2. converting a float module/fused float module
+                    # to float module, we need to attach
+                    # weight observer to the conv module and run it
+                    # with conv weight
+                    if isinstance(original_module, torch.nn.intrinsic._FusedModule):
+                        fused_module = original_module
+                        float_module = fused_module[0]  # type: ignore[index]
+                    assert qconfig is not None
                     weight_post_process = qconfig.weight()
                     # run weight observer
-                    weight_post_process(fmodule.weight)  # type: ignore[operator]
-                    weight_qparams = get_qparam_dict(weight_post_process)
-                    ref_qmodule_cls = get_static_quant_module_class(type(fmodule), is_reference=True)
-                    ref_qmodule = ref_qmodule_cls.from_float(fmodule, weight_qparams)
+                    weight_post_process(float_module.weight)  # type: ignore[operator]
+                weight_qparams = get_qparam_dict(weight_post_process)
+                ref_qmodule_cls = get_static_quant_module_class(type(float_module), is_reference=True)
+                ref_qmodule = ref_qmodule_cls.from_float(float_module, weight_qparams)
+                if fused_module is not None:
+                    fused_module[0] = ref_qmodule
+                else:
                     parent_name, name = _parent_name(node.target)
                     setattr(modules[parent_name], name, ref_qmodule)
```

torch/ao/quantization/fx/backend_config_dict/tensorrt.py

Lines changed: 60 additions & 0 deletions
```diff
@@ -33,26 +33,86 @@ def get_tensorrt_backend_config_dict():
             weighted_op_qint8_dtype_config,
         ]
     }
+    # TODO: maybe make "pattern" to be a list of patterns
+    # TODO: current patterns are the ones after fusion, we will want to expose fusion
+    # here as well in the future, maybe we need to
+    # linear_relu_mm_config = {
+    #     "pattern": (torch.nn.ReLU, torch.nn.Linear),
+    #     "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+    #     "dtype_configs": [
+    #         weighted_op_qint8_dtype_config,
+    #     ]
+    # }
+    # linear_relu_mf_config = {
+    #     "pattern": (torch.nn.functional.relu, torch.nn.Linear),
+    #     "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+    #     "dtype_configs": [
+    #         weighted_op_qint8_dtype_config,
+    #     ]
+    # }
+
+    linear_relu_fused_config = {
+        "pattern": torch.nn.intrinsic.LinearReLU,
+        "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+        "dtype_configs": [
+            weighted_op_qint8_dtype_config,
+        ]
+    }
     conv_module_config = {
         "pattern": torch.nn.Conv2d,
         "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
         "dtype_configs": [
             weighted_op_qint8_dtype_config,
         ]
     }
+    conv_relu_1d_fused_config = {
+        "pattern": torch.nn.intrinsic.ConvReLU1d,
+        "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+        "dtype_configs": [
+            weighted_op_qint8_dtype_config,
+        ]
+    }
+    conv_relu_2d_fused_config = {
+        "pattern": torch.nn.intrinsic.ConvReLU2d,
+        "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+        "dtype_configs": [
+            weighted_op_qint8_dtype_config,
+        ]
+    }
+    conv_relu_3d_fused_config = {
+        "pattern": torch.nn.intrinsic.ConvReLU3d,
+        "observation_type": ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT,
+        "dtype_configs": [
+            weighted_op_qint8_dtype_config,
+        ]
+    }
     cat_config = {
         "pattern": torch.cat,
         "observation_type": ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
         "dtype_configs": [
             non_weighted_op_qint8_dtype_config,
         ]
     }
+    identity_config = {
+        "pattern": torch.nn.Identity,
+        "observation_type": ObservationType.OUTPUT_SHARE_OBSERVER_WITH_INPUT,
+        "dtype_configs": [
+            non_weighted_op_qint8_dtype_config,
+        ]
+    }
     return {
         # optional
         "name": "tensorrt",
         "configs": [
             linear_module_config,
+            linear_relu_fused_config,
             conv_module_config,
+            # conv1d is not supported in fx2trt
+            # conv_relu_1d_fused_config,
+            conv_relu_2d_fused_config,
+            # conv3d is not supported in fx2trt
+            # conv_relu_3d_fused_config,
             cat_config,
+            identity_config,
         ]
     }
```
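
A small usage sketch for the extended backend config (the import path mirrors the file location above; how prepare/convert consume the dict is outside this diff):

```python
from torch.ao.quantization.fx.backend_config_dict.tensorrt import (
    get_tensorrt_backend_config_dict,
)

# Print which patterns the TensorRT backend now recognizes and how each
# is observed; the LinearReLU, ConvReLU2d, and Identity entries are the
# ones added in this commit.
cfg = get_tensorrt_backend_config_dict()
for entry in cfg["configs"]:
    print(entry["pattern"], "->", entry["observation_type"])
```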
