Skip to content

Commit 537bb9d

Browse files
pkchmn-robot
authored andcommitted
Allow specifying ops that the regularizer does not pass during traversal.
PiperOrigin-RevId: 248683238
1 parent f67a608 commit 537bb9d

File tree

7 files changed

+104
-17
lines changed

7 files changed

+104
-17
lines changed

README.md

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,14 @@ models with batch norm; it requires that batch norm scale is enabled.
130130
* *Latency* optimizes for the estimated inference latency of the network, based
131131
on the specific hardware characteristics.
132132

133-
## Example: Adding a FLOPs Regularizer
133+
## Examples
134+
135+
### Adding a FLOPs Regularizer
134136

135137
The example below demonstrates how to use MorphNet to reduce the number of FLOPs
136-
in your model.
138+
in your model. In this example, the regularizer will traverse the graph
139+
starting with `logits`, and will not go past any op whose name matches the regex
140+
`/images.*`; this allows to specify the subgraph for MorphNet to optimize.
137141

138142
```python
139143
from morph_net.network_regularizers import flop_regularizer
@@ -142,7 +146,7 @@ from morph_net.tools import structure_exporter
142146
logits = build_model()
143147

144148
network_regularizer = flop_regularizer.GammaFlopsRegularizer(
145-
[logits.op], gamma_threshold=1e-3)
149+
[logits.op], input_boundary=[images, labels], gamma_threshold=1e-3)
146150
regularization_strength = 1e-10
147151
regularizer_loss = (network_regularizer.get_regularization_term() * regularization_strength)
148152

morph_net/framework/op_regularizer_manager.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import print_function
66

77
import collections
8+
89
from morph_net.framework import concat_and_slice_regularizers
910
from morph_net.framework import constant_op_regularizer
1011
from morph_net.framework import grouping_regularizers
@@ -48,6 +49,7 @@ def __init__(
4849
create_grouping_regularizer=grouping_regularizers.MaxGroupingRegularizer,
4950
force_group=None,
5051
regularizer_blacklist=None,
52+
input_boundary=None,
5153
iteration_limit=ITERATION_LIMIT):
5254
"""Creates an instance of OpRegularizerManager.
5355
@@ -84,6 +86,8 @@ def __init__(
8486
multiple patterns in a single regex.
8587
regularizer_blacklist: List of regex for ops that should not be
8688
regularized.
89+
input_boundary: A list of ops that represent the input boundary of the
90+
subgraph being regularized (input boundary is not regularized).
8791
iteration_limit: Integer iteration limit for OpRegularizerManager to
8892
finish analyzing the network. If the limit is reached, it is assumed
8993
that OpRegularizerManager got stuck in a loop.
@@ -115,7 +119,7 @@ def __init__(
115119

116120
# Start DFS from outputs to find all source ops.
117121
tf.logging.info('OpRegularizerManager starting analysis from: %s.', ops)
118-
self._dfs_for_source_ops(ops)
122+
self._dfs_for_source_ops(ops, input_boundary)
119123
tf.logging.info('OpRegularizerManager found %d ops and %d sources.',
120124
len(self._all_ops), len(self._op_deque))
121125

@@ -581,18 +585,25 @@ def _get_source_slices(self, sizes, op_slices):
581585
size_index += 1
582586
return is_source
583587

584-
def _dfs_for_source_ops(self, ops):
588+
def _dfs_for_source_ops(self, ops, input_boundary=None):
585589
"""Performs DFS from ops and finds source ops to process.
586590
587591
Args:
588-
ops: List of tf.Operation.
592+
ops: A list of tf.Operation's.
593+
input_boundary: A list of ops where traversal should terminate.
589594
"""
595+
if input_boundary:
596+
input_boundary = set(input_boundary)
597+
else:
598+
input_boundary = set()
590599
to_visit = list(ops)
591600
visited = set()
592601
while to_visit:
593602
# Get next op and mark as visited.
594603
op = to_visit.pop()
595604
visited.add(op)
605+
if op in input_boundary:
606+
continue
596607
self._all_ops.add(op)
597608

598609
# Check if op is a source by querying OpHandler.

morph_net/network_regularizers/activation_regularizer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def __init__(
2727
gamma_threshold,
2828
regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
2929
decorator_parameters=None,
30+
input_boundary=None,
3031
force_group=None,
3132
regularizer_blacklist=None):
3233
"""Creates a GammaActivationRegularizer object.
@@ -42,6 +43,8 @@ def __init__(
4243
decorator_parameters: A dictionary of parameters to pass to the decorator
4344
factory. To be used only with decorators that requires parameters,
4445
otherwise use None.
46+
input_boundary: A list of ops that represent the input boundary of the
47+
subgraph being regularized (input boundary is not regularized).
4548
force_group: List of regex for ops that should be force-grouped. Each
4649
regex corresponds to a separate group. Use '|' operator to specify
4750
multiple patterns in a single regex. See op_regularizer_manager for more
@@ -63,6 +66,7 @@ def __init__(
6366
self._manager = orm.OpRegularizerManager(
6467
ops,
6568
op_handler_dict,
69+
input_boundary=input_boundary,
6670
force_group=force_group,
6771
regularizer_blacklist=regularizer_blacklist)
6872
self._calculator = cost_calculator.CostCalculator(
@@ -97,6 +101,7 @@ def __init__(
97101
l1_fraction=0,
98102
regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
99103
decorator_parameters=None,
104+
input_boundary=None,
100105
force_group=None,
101106
regularizer_blacklist=None):
102107
"""Creates a GroupLassoActivationRegularizer object.
@@ -113,6 +118,8 @@ def __init__(
113118
decorator_parameters: A dictionary of parameters to pass to the decorator
114119
factory. To be used only with decorators that requires parameters,
115120
otherwise use None.
121+
input_boundary: A list of ops that represent the input boundary of the
122+
subgraph being regularized (input boundary is not regularized).
116123
force_group: List of regex for ops that should be force-grouped. Each
117124
regex corresponds to a separate group. Use '|' operator to specify
118125
multiple patterns in a single regex. See op_regularizer_manager for more
@@ -145,6 +152,7 @@ def __init__(
145152
self._manager = orm.OpRegularizerManager(
146153
ops,
147154
op_handler_dict,
155+
input_boundary=input_boundary,
148156
force_group=force_group,
149157
regularizer_blacklist=regularizer_blacklist)
150158
self._calculator = cost_calculator.CostCalculator(

morph_net/network_regularizers/flop_regularizer.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,21 +27,24 @@ def __init__(
2727
gamma_threshold,
2828
regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
2929
decorator_parameters=None,
30+
input_boundary=None,
3031
force_group=None,
3132
regularizer_blacklist=None):
3233
"""Creates a GammaFlopsRegularizer object.
3334
3435
Args:
3536
ops: A list of tf.Operation. An OpRegularizer will be created for all the
3637
ops in `ops`, and recursively for all ops they depend on via data
37-
dependency. Typically `ops` would contain a single tf.Operation, which
38-
is the output of the network.
38+
dependency that does not involve input ops. Typically `ops` would
39+
contain a single tf.Operation, which is the output of the network.
3940
gamma_threshold: A float scalar, will be used as a 'gamma_threshold' for
4041
all instances GammaL1Regularizer created by this class.
4142
regularizer_decorator: A class of OpRegularizer decorator to use.
4243
decorator_parameters: A dictionary of parameters to pass to the decorator
4344
factory. To be used only with decorators that requires parameters,
4445
otherwise use None.
46+
input_boundary: A list of ops that represent the input boundary of the
47+
subgraph being regularized (input boundary is not regularized).
4548
force_group: List of regex for ops that should be force-grouped. Each
4649
regex corresponds to a separate group. Use '|' operator to specify
4750
multiple patterns in a single regex. See op_regularizer_manager for more
@@ -63,6 +66,7 @@ def __init__(
6366
self._manager = orm.OpRegularizerManager(
6467
ops,
6568
op_handler_dict,
69+
input_boundary=input_boundary,
6670
force_group=force_group,
6771
regularizer_blacklist=regularizer_blacklist)
6872
self._calculator = cost_calculator.CostCalculator(
@@ -97,6 +101,7 @@ def __init__(
97101
l1_fraction=0,
98102
regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
99103
decorator_parameters=None,
104+
input_boundary=None,
100105
force_group=None,
101106
regularizer_blacklist=None,
102107
convert_to_variable=True):
@@ -114,6 +119,8 @@ def __init__(
114119
decorator_parameters: A dictionary of parameters to pass to the decorator
115120
factory. To be used only with decorators that requires parameters,
116121
otherwise use None.
122+
input_boundary: A list of ops that represent the input boundary of the
123+
subgraph being regularized (input boundary is not regularized).
117124
force_group: List of regex for ops that should be force-grouped. Each
118125
regex corresponds to a separate group. Use '|' operator to specify
119126
multiple patterns in a single regex. See op_regularizer_manager for more
@@ -149,6 +156,7 @@ def __init__(
149156
self._manager = orm.OpRegularizerManager(
150157
ops,
151158
op_handler_dict,
159+
input_boundary=input_boundary,
152160
force_group=force_group,
153161
regularizer_blacklist=regularizer_blacklist)
154162
self._calculator = cost_calculator.CostCalculator(

morph_net/network_regularizers/flop_regularizer_test.py

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,18 +49,26 @@ def BuildModel(self):
4949
#
5050
# (the model has two "outputs", conv3 and conv4).
5151
#
52+
53+
# op.name: 'Const'
5254
image = tf.constant(0.0, shape=[1, 17, 19, NUM_CHANNELS])
53-
conv1 = slim.layers.conv2d(image, 13, [7, 5], padding='SAME', scope='conv1')
54-
conv2 = slim.layers.conv2d(image, 23, [1, 1], padding='SAME', scope='conv2')
55-
concat = tf.concat([conv1, conv2], 3)
55+
# op.name: 'conv1/Conv2D'
56+
self.conv1 = slim.layers.conv2d(
57+
image, 13, [7, 5], padding='SAME', scope='conv1')
58+
self.conv2 = slim.layers.conv2d(
59+
image, 23, [1, 1], padding='SAME', scope='conv2')
60+
self.concat = tf.concat([self.conv1, self.conv2], 3)
5661
self.conv3 = slim.layers.conv2d(
57-
concat, 29, [3, 3], stride=2, padding='SAME', scope='conv3')
62+
self.concat, 29, [3, 3], stride=2, padding='SAME', scope='conv3')
5863
self.conv4 = slim.layers.conv2d(
59-
concat, 31, [1, 1], stride=1, padding='SAME', scope='conv4')
64+
self.concat, 31, [1, 1], stride=1, padding='SAME', scope='conv4')
6065
self.name_to_var = {v.op.name: v for v in tf.global_variables()}
6166

67+
def AddRegularizer(self, input_boundary=None):
6268
self.gamma_flop_reg = flop_regularizer.GammaFlopsRegularizer(
63-
[self.conv3.op, self.conv4.op], gamma_threshold=0.45)
69+
[self.conv3.op, self.conv4.op],
70+
gamma_threshold=0.45,
71+
input_boundary=input_boundary)
6472

6573
def GetConv(self, name):
6674
return tf.get_default_graph().get_operation_by_name(name + '/Conv2D')
@@ -84,8 +92,18 @@ def GetLoss(self, conv):
8492
with self.cached_session():
8593
return self.gamma_flop_reg.get_regularization_term(conv).eval()
8694

87-
def testCost(self,):
95+
def GetSourceOps(self):
96+
op_regularizer_manager = self.gamma_flop_reg.op_regularizer_manager
97+
return [
98+
op.name
99+
for op in op_regularizer_manager.ops
100+
if op_regularizer_manager.is_source_op(op)
101+
]
102+
103+
def testCost(self):
88104
self.BuildWithBatchNorm(fused=True)
105+
self.AddRegularizer(input_boundary=None)
106+
89107
# Conv1 has 7 gammas above 0.45, and NUM_CHANNELS inputs (from the image).
90108
conv = self.GetConv('conv1')
91109
self.assertEqual(_coeff(conv) * 7 * NUM_CHANNELS, self.GetCost([conv]))
@@ -107,8 +125,40 @@ def testCost(self,):
107125
self.assertEqual(
108126
self.GetCost(convs[:1]) + self.GetCost(convs[1:]), self.GetCost(convs))
109127

128+
def testInputBoundaryNone(self):
129+
self.BuildWithBatchNorm(fused=True)
130+
self.AddRegularizer(input_boundary=None)
131+
self.assertCountEqual(self.GetSourceOps(), [
132+
'conv1/BatchNorm/FusedBatchNorm', 'conv2/BatchNorm/FusedBatchNorm',
133+
'conv3/BatchNorm/FusedBatchNorm', 'conv4/BatchNorm/FusedBatchNorm'
134+
])
135+
136+
def testInputBoundaryConv3(self):
137+
# Only block one path, can still reach all other convolutions.
138+
self.BuildWithBatchNorm(fused=True)
139+
self.AddRegularizer(input_boundary=[self.conv3.op])
140+
self.assertCountEqual(self.GetSourceOps(), [
141+
'conv1/BatchNorm/FusedBatchNorm', 'conv2/BatchNorm/FusedBatchNorm',
142+
'conv4/BatchNorm/FusedBatchNorm'
143+
])
144+
145+
def testInputBoundaryConv3And4(self):
146+
# Block both paths, can no longer reach Concat and earlier convolutions.
147+
self.BuildWithBatchNorm(fused=True)
148+
self.AddRegularizer(input_boundary=[self.conv3.op, self.conv4.op])
149+
self.assertCountEqual(self.GetSourceOps(), [])
150+
151+
def testInputBoundaryConcat(self):
152+
# Block concat, can only see conv3 and conv4.
153+
self.BuildWithBatchNorm(fused=True)
154+
self.AddRegularizer(input_boundary=[self.concat.op])
155+
self.assertCountEqual(
156+
self.GetSourceOps(),
157+
['conv3/BatchNorm/FusedBatchNorm', 'conv4/BatchNorm/FusedBatchNorm'])
158+
110159
def testLossDecorated(self):
111160
self.BuildWithBatchNorm(True)
161+
self.AddRegularizer()
112162
# Create network regularizer with DummyDecorator op regularization.
113163
self.gamma_flop_reg = flop_regularizer.GammaFlopsRegularizer(
114164
[self.conv3.op, self.conv4.op],

morph_net/network_regularizers/latency_regularizer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def __init__(
2626
batch_size=1,
2727
regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
2828
decorator_parameters=None,
29+
input_boundary=None,
2930
force_group=None,
3031
regularizer_blacklist=None) -> None:
3132
"""Creates a GammaLatencyRegularizer object.
@@ -49,6 +50,8 @@ def __init__(
4950
decorator_parameters: A dictionary of parameters to pass to the decorator
5051
factory. To be used only with decorators that requires parameters,
5152
otherwise use None.
53+
input_boundary: A list of ops that represent the input boundary of the
54+
subgraph being regularized (input boundary is not regularized).
5255
force_group: List of regex for ops that should be force-grouped. Each
5356
regex corresponds to a separate group. Use '|' operator to specify
5457
multiple patterns in a single regex. See op_regularizer_manager for
@@ -69,7 +72,7 @@ def __init__(
6972
})
7073

7174
self._manager = orm.OpRegularizerManager(
72-
ops, op_handler_dict,
75+
ops, op_handler_dict, input_boundary=input_boundary,
7376
force_group=force_group, regularizer_blacklist=regularizer_blacklist)
7477
self._calculator = cost_calculator.CostCalculator(
7578
self._manager,

morph_net/network_regularizers/model_size_regularizer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def __init__(
2424
gamma_threshold,
2525
regularizer_decorator: Type[generic_regularizers.OpRegularizer] = None,
2626
decorator_parameters=None,
27+
input_boundary=None,
2728
force_group=None,
2829
regularizer_blacklist=None):
2930
"""Creates a GammaModelSizeRegularizer object.
@@ -41,6 +42,8 @@ def __init__(
4142
decorator_parameters: A dictionary of parameters to pass to the decorator
4243
factory. To be used only with decorators that requires parameters,
4344
otherwise use None.
45+
input_boundary: A list of ops that represent the input boundary of the
46+
subgraph being regularized (input boundary is not regularized).
4447
force_group: List of regex for ops that should be force-grouped. Each
4548
regex corresponds to a separate group. Use '|' operator to specify
4649
multiple patterns in a single regex. See op_regularizer_manager for
@@ -60,7 +63,7 @@ def __init__(
6063
})
6164

6265
self._manager = orm.OpRegularizerManager(
63-
ops, op_handler_dict,
66+
ops, op_handler_dict, input_boundary=input_boundary,
6467
force_group=force_group, regularizer_blacklist=regularizer_blacklist)
6568
self._calculator = cost_calculator.CostCalculator(
6669
self._manager, resource_function.model_size_function)

0 commit comments

Comments
 (0)