Commit e5a1ee0

jerryzh168 authored and facebook-github-bot committed
[quant][graphmode] Refactor fusion to use the new Pattern format (pytorch#68770)
Summary: Pull Request resolved: pytorch#68770

The previous fusion pass only works for a sequence of ops, which is not general enough for fusion patterns defined by a subgraph; this PR refactors it to be more general.

Test Plan:

```
python test/test_quantization.py TestFuseFx
```

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D32602637

fbshipit-source-id: a7897c62081b9d71c67fb56e78484cf68deaacf6
1 parent 1433160 commit e5a1ee0
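
At a glance, the refactor changes how fuser mappings are keyed. A minimal sketch of the two key shapes (both keys appear verbatim in the diffs below; the variable names here are illustrative only):

```python
import torch.nn as nn

# Old format: a flat tuple of op types in execution order.
old_key = (nn.Conv2d, nn.BatchNorm2d, nn.ReLU)

# New format: reverse execution order, with the producer ops nested inside
# their consumer, so conv -> bn -> relu becomes:
new_key = (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))
```

The reversed, nested shape mirrors how the matcher below walks the torch.fx graph from a node backwards through its arguments.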

File tree

6 files changed (+112, -39)


torch/ao/quantization/fuser_method_mappings.py

Lines changed: 39 additions & 1 deletion
```diff
@@ -2,9 +2,11 @@
 import torch.nn.intrinsic as nni
 
 from typing import Union, Callable, Tuple, Dict, Optional, Type
+from torch.ao.quantization.utils import Pattern
 
 from torch.ao.quantization.utils import get_combined_dict
 
+
 def fuse_conv_bn(conv, bn):
     r"""Given the conv and bn modules, fuses them and returns the fused module
 
@@ -104,7 +106,7 @@ def fuse_linear_bn(linear, bn):
     else:
         return nn.utils.fusion.fuse_linear_bn_eval(linear, bn)
 
-DEFAULT_OP_LIST_TO_FUSER_METHOD : Dict[Tuple, Union[nn.Sequential, Callable]] = {
+DEFAULT_OP_LIST_TO_FUSER_METHOD: Dict[Tuple, Union[nn.Sequential, Callable]] = {
     (nn.Conv1d, nn.BatchNorm1d): fuse_conv_bn,
     (nn.Conv1d, nn.BatchNorm1d, nn.ReLU): fuse_conv_bn_relu,
     (nn.Conv2d, nn.BatchNorm2d): fuse_conv_bn,
@@ -131,3 +133,39 @@ def get_fuser_method(op_list, additional_fuser_method_mapping=None):
     fuser_method = all_mappings.get(op_list, None)
     assert fuser_method is not None, "did not find fuser method for: {} ".format(op_list)
     return fuser_method
+
+def reverse2(f):
+    return lambda x, y: f(y, x)
+
+def reverse3(f):
+    def reversed(x, w):
+        y, z = w
+        return f(z, y, x)
+    return reversed
+
+DEFAULT_PATTERN_TO_FUSER_METHOD: Dict[Pattern, Union[nn.Sequential, Callable]] = {
+    (nn.BatchNorm1d, nn.Conv1d): reverse2(fuse_conv_bn),
+    (nn.ReLU, (nn.BatchNorm1d, nn.Conv1d)): reverse3(fuse_conv_bn_relu),
+    (nn.BatchNorm2d, nn.Conv2d): reverse2(fuse_conv_bn),
+    (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)): reverse3(fuse_conv_bn_relu),
+    (nn.BatchNorm3d, nn.Conv3d): reverse2(fuse_conv_bn),
+    (nn.ReLU, (nn.BatchNorm3d, nn.Conv3d)): reverse3(fuse_conv_bn_relu),
+    (nn.ReLU, nn.Conv1d): reverse2(nni.ConvReLU1d),
+    (nn.ReLU, nn.Conv2d): reverse2(nni.ConvReLU2d),
+    (nn.ReLU, nn.Conv3d): reverse2(nni.ConvReLU3d),
+    (nn.BatchNorm1d, nn.Linear): reverse2(fuse_linear_bn),
+    (nn.ReLU, nn.Linear): reverse2(nni.LinearReLU),
+    (nn.ReLU, nn.BatchNorm2d): reverse2(nni.BNReLU2d),
+    (nn.ReLU, nn.BatchNorm3d): reverse2(nni.BNReLU3d),
+}
+
+def get_fuser_method_new(op_pattern, fuser_method_mapping=None):
+    """ This will be made default after we deprecate get_fuser_method.
+    Would like to implement this first and have a separate PR for deprecation.
+    """
+    if fuser_method_mapping is None:
+        fuser_method_mapping = DEFAULT_PATTERN_TO_FUSER_METHOD
+
+    fuser_method = fuser_method_mapping.get(op_pattern, None)
+    assert fuser_method is not None, "did not find fuser method for: {} ".format(op_pattern)
+    return fuser_method
```
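
A hypothetical usage sketch of the new lookup (the import path and functions come from the diff above; the module instances are illustrative):

```python
import torch.nn as nn
from torch.ao.quantization.fuser_method_mappings import get_fuser_method_new

# fuse_conv_bn requires conv and bn to be in the same training mode;
# eval mode selects the inference fusion path.
conv = nn.Conv2d(3, 3, 1).eval()
bn = nn.BatchNorm2d(3).eval()

# Keys are in reverse execution order: (bn, conv) rather than (conv, bn).
fuser_method = get_fuser_method_new((nn.BatchNorm2d, nn.Conv2d))

# reverse2 swaps the arguments back, so this calls fuse_conv_bn(conv, bn)
# and returns a fused Conv2d.
fused = fuser_method(bn, conv)
```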

torch/ao/quantization/fx/fuse.py

Lines changed: 25 additions & 17 deletions
```diff
@@ -1,5 +1,3 @@
-from typing import Dict, Any
-
 from torch.fx import (
     GraphModule,
     Node,
@@ -19,11 +17,9 @@
 
 from .fusion_patterns import *  # noqa: F401,F403
 
-from typing import Callable, Tuple
-from typing import Optional
-
-from .quantization_types import Pattern
+from typing import Callable, Tuple, Dict, Any, Optional, List
 
+from .quantization_types import Pattern, NodePattern
 
 class Fuser:
     def fuse(
@@ -50,11 +46,18 @@ def load_arg(a):
             return map_arg(a, lambda node: env[node.name])
 
         for node in input_graph.nodes:
-            root_node, obj = fusion_pairs.get(node.name, (None, None))
-            if root_node is node:
+            maybe_last_node, pattern, matched_node_pattern, obj = \
+                fusion_pairs.get(node.name, (None, None, None, None))
+            if maybe_last_node is node:
                 assert obj is not None
-                env[node.name] = obj.fuse(self, load_arg, fuse_custom_config_dict)
-            elif root_node is None:
+                # TODO: currently we hard code the root node, which only works for
+                # a tuple of two nodes, we want to make this more general to
+                # support more complex patterns
+                root_node = matched_node_pattern[-1]  # type: ignore[index]
+                env[node.name] = obj.fuse(
+                    self, load_arg, root_node, matched_node_pattern,  # type: ignore[arg-type]
+                    fuse_custom_config_dict)
+            elif maybe_last_node is None:
                 env[node.name] = self.fused_graph.node_copy(node, load_arg)
             # node matched in patterns and is not root is removed here
@@ -65,25 +68,30 @@ def load_arg(a):
     def _find_matches(
             self, root: GraphModule, graph: Graph,
             patterns: Dict[Pattern, Callable]
-    ) -> Dict[str, Tuple[Node, FuseHandler]]:
+    ) -> Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler]]:
         modules = dict(root.named_modules())
-        match_map : Dict[str, Tuple[Node, FuseHandler]] = {}  # node name -> (root_node, match_value)
+        match_map : Dict[str, Tuple[Node, Pattern, NodePattern, FuseHandler]] = {}  # node name -> (root_node, match_value)
 
-        def apply_match(pattern, node, match):
+        def apply_match(pattern, node, match, matched_node_pattern):
             if isinstance(pattern, tuple):
                 s, *args = pattern
-                apply_match(s, node, match)
+                current_node_pattern: List[Node] = []
+                apply_match(s, node, match, current_node_pattern)
                 for subpattern, arg in zip(args, node.args):
-                    apply_match(subpattern, arg, match)
+                    apply_match(subpattern, arg, match, current_node_pattern)
+                matched_node_pattern.append(tuple(current_node_pattern))
             else:
                 # the first pattern that matches will take precedence
                 if node.name not in match_map:
-                    match_map[node.name] = match
+                    matched_node_pattern.append(node)
+                    root_node, pattern, handler = match
+                    match_map[node.name] = (root_node, pattern, matched_node_pattern, handler)
 
         for node in reversed(graph.nodes):
             if node.name not in match_map:
                 for pattern, value in patterns.items():
+                    matched_node_pattern: List[Node] = []
                     if is_match(modules, node, pattern):
-                        apply_match(pattern, node, (node, value(self, node)))
+                        apply_match(pattern, node, (node, pattern, value(self, node)), matched_node_pattern)
 
         return match_map
```
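
To make the new bookkeeping concrete, here is the shape of the data `_find_matches` now records, with toy strings standing in for torch.fx Nodes (the names are hypothetical):

```python
# For relu(bn(conv(x))) matched against (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d)),
# apply_match records the concrete nodes in the same nested layout as the pattern:
matched_node_pattern = ["relu_node", ("bn_node", "conv_node")]

# Each participating node name now maps to a 4-tuple instead of the old
# (root_node, handler) pair:
match_map_entry = (
    "relu_node",           # last node of the match; checked as maybe_last_node in fuse()
    "pattern_key",         # the Pattern the match was found under
    matched_node_pattern,  # the concrete nodes, nested like the pattern
    "handler_instance",    # the FuseHandler constructed for the match
)
```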

torch/ao/quantization/fx/fusion_patterns.py

Lines changed: 42 additions & 19 deletions
```diff
@@ -4,10 +4,12 @@
     register_fusion_pattern,
 )
 from .utils import _parent_name
-from .quantization_types import QuantizerCls
+from .quantization_types import QuantizerCls, NodePattern
 from ..fuser_method_mappings import get_fuser_method
+from ..fuser_method_mappings import get_fuser_method_new
 from abc import ABC, abstractmethod
 from typing import Any, Callable, Dict
+from .match_utils import MatchAllNode
 
 # ---------------------
 # Fusion Pattern Registrations
@@ -21,7 +23,11 @@ def __init__(self, quantizer: QuantizerCls, node: Node):
         pass
 
     @abstractmethod
-    def fuse(self, quantizer: QuantizerCls, load_arg: Callable,
+    def fuse(self,
+             quantizer: QuantizerCls,
+             load_arg: Callable,
+             root_node: Node,
+             matched_node_pattern: NodePattern,
              fuse_custom_config_dict: Dict[str, Any]) -> Node:
         pass
 
@@ -61,7 +67,11 @@ def __init__(self, quantizer: QuantizerCls, node: Node):
         self.conv_or_linear_node = node
         self.conv_or_linear = quantizer.modules[self.conv_or_linear_node.target]
 
-    def fuse(self, quantizer: QuantizerCls, load_arg: Callable,
+    def fuse(self,
+             quantizer: QuantizerCls,
+             load_arg: Callable,
+             root_node: Node,
+             matched_node_pattern: NodePattern,
              fuse_custom_config_dict: Dict[str, Any]) -> Node:
         additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
         op_list = []
@@ -116,23 +126,36 @@ def __init__(self, quantizer: QuantizerCls, node: Node):
         self.module_node = node
         self.module = quantizer.modules[self.module_node.target]
 
-    def fuse(self, quantizer: QuantizerCls, load_arg: Callable,
+    def fuse(self, quantizer: QuantizerCls,
+             load_arg: Callable,
+             root_node: Node,
+             matched_node_pattern: NodePattern,
              fuse_custom_config_dict: Dict[str, Any]) -> Node:
         additional_fuser_method_mapping = fuse_custom_config_dict.get("additional_fuser_method_mapping", {})
-        op_list = []
+        assert root_node.op == "call_module", "Expecting module node to be a call_module Node"
+        root_module = quantizer.modules[root_node.target]
+        assert len(additional_fuser_method_mapping) == 0, "Fusion implementation is " \
+            "undergoing changes, additional_fuser_method_mapping is not supported currently."
+        def get_module(n):
+            if n.op == "call_module":
+                return quantizer.modules[n.target]
+            elif n.op == "call_function" and n.target == torch.nn.functional.relu:
+                relu = torch.nn.ReLU()
+                relu.training = root_module.training
+                return relu
+            return MatchAllNode
+
+        matched_modules = tuple(map(get_module, matched_node_pattern))
         # since relu can be used multiple times, we'll need to create a relu module for each match
-        if self.relu_node.op == 'call_module':
-            relu = torch.nn.ReLU(quantizer.modules[self.relu_node.target].inplace)
-        else:
-            # TODO: get inplace argument from functional
-            relu = torch.nn.ReLU()
-        relu.training = self.module.training
-        op_list.append(relu)
-        op_list.append(self.module)
 
-        op_list.reverse()
-        op_type_list = tuple(type(m) for m in op_list)
-        module_parent_name, module_name = _parent_name(self.module_node.target)
-        fuser_method = get_fuser_method(op_type_list, additional_fuser_method_mapping)
-        setattr(quantizer.modules[module_parent_name], module_name, fuser_method(*op_list))
-        return quantizer.fused_graph.node_copy(self.module_node, load_arg)
+        def get_type(m):
+            return type(m)
+
+        matched_module_types = tuple(map(get_type, matched_modules))
+        module_parent_name, module_name = _parent_name(root_node.target)
+        fuser_method = get_fuser_method_new(matched_module_types)
+        # TODO: change the signature for fuser_method to take matched module patterns
+        # as input
+        fused_module = fuser_method(*matched_modules)
+        setattr(quantizer.modules[module_parent_name], module_name, fused_module)
+        return quantizer.fused_graph.node_copy(root_node, load_arg)
```
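
An end-to-end sketch of the new fuse() data flow for a bn → relu match. This is standalone and hypothetical: the quantizer and Node plumbing are omitted, but the import path and mapping entries come from the diffs above:

```python
import torch.nn as nn
import torch.nn.intrinsic as nni
from torch.ao.quantization.fuser_method_mappings import get_fuser_method_new

bn = nn.BatchNorm2d(4)
relu = nn.ReLU()

# matched_node_pattern resolves to modules in reverse execution order: (relu, bn).
matched_modules = (relu, bn)
matched_module_types = tuple(type(m) for m in matched_modules)  # (ReLU, BatchNorm2d)

# Looks up (nn.ReLU, nn.BatchNorm2d) -> reverse2(nni.BNReLU2d).
fuser_method = get_fuser_method_new(matched_module_types)

# reverse2 restores constructor argument order: nni.BNReLU2d(bn, relu).
fused_module = fuser_method(*matched_modules)
assert isinstance(fused_module, nni.BNReLU2d)
```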

torch/ao/quantization/fx/match_utils.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -22,6 +22,7 @@
 MatchResult = Tuple[Node, List[Node], Optional[Pattern], QuantizeHandler,
                     QConfigAny]
 
+# TODO: maybe rename this to MatchInputNode
 class MatchAllNode:
     """ A node pattern that matches all nodes
     """
```

torch/ao/quantization/fx/quantization_types.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -1,6 +1,9 @@
-from typing import Any
+from typing import Any, Tuple, Union
+from torch.fx import Node
 from ..utils import Pattern  # noqa: F401
 
+NodePattern = Union[Tuple[Node, Node], Tuple[Node, Tuple[Node, Node]], Any]
+
 # This is the Quantizer class instance from torch/quantization/fx/quantize.py.
 # Define separately to prevent circular imports.
 # TODO(future PR): improve this.
```
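
Illustrative NodePattern values, again with toy strings in place of real torch.fx Node instances:

```python
# A flat two-node match, e.g. bn -> relu:
flat = ("relu_node", "bn_node")

# A nested match, e.g. conv -> bn -> relu, mirroring the nested Pattern:
nested = ("relu_node", ("bn_node", "conv_node"))
```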

torch/ao/quantization/utils.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -10,7 +10,7 @@
 # Type for fusion patterns, it can be more complicated than the following actually,
 # see pattern.md for docs
 # TODO: not sure if typing supports recursive data types
-Pattern = Union[Callable, Tuple[Callable, Callable], Tuple[Callable, Callable, Callable]]
+Pattern = Union[Callable, Tuple[Callable, Callable], Tuple[Callable, Tuple[Callable, Callable]], Any]
 
 def get_combined_dict(default_dict, additional_dict):
     d = default_dict.copy()
```
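
Examples of values the widened Pattern type admits; the nested form is what the new DEFAULT_PATTERN_TO_FUSER_METHOD keys use:

```python
import torch.nn as nn

p1 = nn.Conv2d                               # Callable
p2 = (nn.BatchNorm2d, nn.Conv2d)             # Tuple[Callable, Callable]
p3 = (nn.ReLU, (nn.BatchNorm2d, nn.Conv2d))  # Tuple[Callable, Tuple[Callable, Callable]]
```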
