Commit f67a608

eladebanmn-robot authored and committed
Support tf.concat on an axis other than channels (-1 or rank - 1).
PiperOrigin-RevId: 248360148
1 parent 5bc73b2 commit f67a608

File tree: 4 files changed (+377, -21 lines)

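Context note (not part of the commit): when tf.concat runs along the channels axis (-1, or equivalently rank - 1), the output's channels are the inputs' channels stacked end to end, which is the `standard` concat the handler already supported. Along any other axis, channel i of every input feeds channel i of the output, so the inputs must be regularized as one group. A minimal sketch of the shape semantics, assuming NHWC tensors and TF 1.x:

import tensorflow as tf

a = tf.zeros([2, 4, 4, 6])
b = tf.zeros([2, 4, 4, 6])

# Channels axis: 12 output channels, one per input channel.
channels_concat = tf.concat([a, b], axis=-1)  # shape [2, 4, 4, 12]

# Spatial axis: still 6 output channels; channel i of `a` and channel i
# of `b` share output channel i, so their regularizers must be grouped.
spatial_concat = tf.concat([a, b], axis=2)  # shape [2, 4, 12, 6]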

morph_net/framework/concat_op_handler.py

Lines changed: 18 additions & 0 deletions
@@ -4,10 +4,17 @@
 from __future__ import division
 from __future__ import print_function

+from morph_net.framework import grouping_op_handler
 from morph_net.framework import op_handler
 from morph_net.framework import op_handler_util


+# The axis arg of tf.concat is a constant tensor stored in the last element of
+# op.inputs. This function accesses the value of that tensor.
+def _get_concat_op_axis(op):
+  return op.inputs[-1].op.get_attr('value').int_val[0]
+
+
 class ConcatOpHandler(op_handler.OpHandler):
   """OpHandler implementation for concat operations."""

@@ -26,6 +33,17 @@ def assign_grouping(self, op, op_reg_manager):
       op: tf.Operation to assign grouping to.
       op_reg_manager: OpRegularizerManager to keep track of the grouping.
     """
+    concat_axis = _get_concat_op_axis(op)
+    # Need to figure out the rank to know if the axis is last.
+    rank = len(op.inputs[0].shape)  # Rank of the first input.
+
+    if concat_axis != -1 and concat_axis != rank - 1:
+      # Concat is actually grouping inputs!
+      handler = grouping_op_handler.GroupingOpHandler()
+      handler.assign_grouping(op, op_reg_manager)
+      return
+
+    # If concat is along the last dimension, this is a `standard` concat.
     # TODO(a1): Consider refactoring this duplicated logic.
     # Check if all input ops have groups, or tell the manager to process them.
     input_ops = op_handler_util.get_input_ops(op, op_reg_manager)
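Note (illustrative, not part of the diff): _get_concat_op_axis relies on the fact that in a TF 1.x graph, the Python axis argument of tf.concat becomes a Const op wired in as the last input of the resulting ConcatV2 op. A small sketch of reading it back:

import tensorflow as tf

x = tf.zeros([7, 12, 12, 3])
concat_op = tf.concat([x, x], axis=2).op  # a ConcatV2 op
axis_const = concat_op.inputs[-1].op  # the trailing 'axis' Const op
print(axis_const.get_attr('value').int_val[0])  # prints 2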

morph_net/framework/concat_op_handler_test.py

Lines changed: 153 additions & 0 deletions
@@ -10,6 +10,7 @@
 import tensorflow as tf

 layers = tf.contrib.layers
+arg_scope = tf.contrib.framework.arg_scope


 class ConcatOpHandlerTest(tf.test.TestCase):

@@ -757,5 +758,157 @@ def testGetInputOutputOpSlices(self):
         self.mock_op_reg_manager))


+class GroupingConcatOpHandlerTest(tf.test.TestCase):
+
+  def _get_scope(self):
+    params = {
+        'trainable': True,
+        'normalizer_fn': layers.batch_norm,
+        'normalizer_params': {
+            'scale': True,
+        },
+    }
+
+    with arg_scope([layers.conv2d], **params) as sc:
+      return sc
+
+  def setUp(self):
+    tf.reset_default_graph()
+
+    # This tests 3 Conv2D ops being concatenated.
+    inputs = tf.zeros([2, 4, 4, 3])
+    with tf.contrib.framework.arg_scope(self._get_scope()):
+      c1 = layers.conv2d(inputs, num_outputs=6, kernel_size=3, scope='conv1')
+      c2 = layers.conv2d(inputs, num_outputs=6, kernel_size=3, scope='conv2')
+      c3 = layers.conv2d(inputs, num_outputs=6, kernel_size=3, scope='conv3')
+      net = tf.concat([c1, c2, c3], axis=2)
+      layers.batch_norm(net)
+
+    g = tf.get_default_graph()
+
+    # Declare OpSlice and OpGroup for ops of interest.
+    self.concat_op = g.get_operation_by_name('concat')
+    self.concat_op_slice = orm.OpSlice(self.concat_op, orm.Slice(0, 6))
+    self.concat_op_group = orm.OpGroup(
+        self.concat_op_slice,
+        omit_source_op_slices=[self.concat_op_slice])
+
+    self.relu1_op = g.get_operation_by_name('conv1/Relu')
+    self.relu1_op_slice = orm.OpSlice(self.relu1_op, orm.Slice(0, 6))
+    self.relu1_op_group = orm.OpGroup(
+        self.relu1_op_slice, omit_source_op_slices=[self.relu1_op_slice])
+
+    self.relu2_op = g.get_operation_by_name('conv2/Relu')
+    self.relu2_op_slice = orm.OpSlice(self.relu2_op, orm.Slice(0, 6))
+    self.relu2_op_group = orm.OpGroup(
+        self.relu2_op_slice, omit_source_op_slices=[self.relu2_op_slice])
+
+    self.relu3_op = g.get_operation_by_name('conv3/Relu')
+    self.relu3_op_slice = orm.OpSlice(self.relu3_op, orm.Slice(0, 6))
+    self.relu3_op_group = orm.OpGroup(
+        self.relu3_op_slice, omit_source_op_slices=[self.relu3_op_slice])
+
+    self.batch_norm_op = g.get_operation_by_name('BatchNorm/FusedBatchNorm')
+    self.batch_norm_op_slice = orm.OpSlice(self.batch_norm_op, orm.Slice(0, 6))
+    self.batch_norm_op_group = orm.OpGroup(
+        self.batch_norm_op_slice,
+        omit_source_op_slices=[self.batch_norm_op_slice])
+
+    self.concat_group = orm.OpGroup(
+        op_slice=None,
+        op_groups=[
+            self.batch_norm_op_group, self.concat_op_group,
+            self.relu1_op_group, self.relu2_op_group, self.relu3_op_group
+        ])
+
+    # Create mock OpRegularizerManager with custom mapping of OpSlice and
+    # OpGroup.
+    self.mock_op_reg_manager = mock.create_autospec(orm.OpRegularizerManager)
+
+    def get_op_slices(op):
+      return self.op_slice_dict.get(op, [])
+
+    def get_op_group(op_slice):
+      return self.op_group_dict.get(op_slice)
+
+    self.mock_op_reg_manager.get_op_slices.side_effect = get_op_slices
+    self.mock_op_reg_manager.get_op_group.side_effect = get_op_group
+    self.mock_op_reg_manager.is_source_op.return_value = False
+    self.mock_op_reg_manager.is_passthrough.return_value = True
+    self.mock_op_reg_manager.ops = [
+        self.concat_op, self.relu1_op, self.relu2_op, self.relu3_op,
+        self.batch_norm_op]
+
+  def test_AssignGroupingOfGroupingConcatNoSlicing(self):
+    # In this test, the output op (batch norm) has size 6 and is not sliced,
+    # and the input Conv2Ds are all of size 6 and are grouped.
+
+    # Map ops to slices.
+    self.op_slice_dict = {
+        self.relu1_op: [self.relu1_op_slice],
+        self.relu2_op: [self.relu2_op_slice],
+        self.relu3_op: [self.relu3_op_slice],
+        self.concat_op: [self.concat_op_slice],
+        self.batch_norm_op: [self.batch_norm_op_slice],
+    }
+
+    # Map each slice to a group.
+    self.op_group_dict = {
+        self.relu1_op_slice: self.relu1_op_group,
+        self.relu2_op_slice: self.relu2_op_group,
+        self.relu3_op_slice: self.relu3_op_group,
+        self.batch_norm_op_slice: self.batch_norm_op_group
+    }
+
+    # Call handler to assign grouping.
+    handler = concat_op_handler.ConcatOpHandler()
+    handler.assign_grouping(self.concat_op, self.mock_op_reg_manager)
+
+    # Verify manager looks up OpSlice for ops of interest.
+    self.mock_op_reg_manager.get_op_slices.assert_has_calls(
+        # Checking for ops to process.
+        [mock.call(self.relu1_op),
+         mock.call(self.relu2_op),
+         mock.call(self.relu3_op),
+         mock.call(self.batch_norm_op),
+         # Initial slice data.
+         mock.call(self.concat_op),
+         mock.call(self.relu1_op),
+         mock.call(self.relu2_op),
+         mock.call(self.relu3_op),
+         mock.call(self.batch_norm_op),
+         # Reslicing.
+         mock.call(self.relu1_op),
+         mock.call(self.relu2_op),
+         mock.call(self.relu3_op),
+         mock.call(self.concat_op),
+         mock.call(self.batch_norm_op),
+         # Refreshing slice data.
+         mock.call(self.relu1_op),
+         mock.call(self.relu2_op),
+         mock.call(self.relu3_op),
+         mock.call(self.batch_norm_op),
+         # Group concat op.
+         mock.call(self.concat_op)])
+
+    # Verify manager does not slice the concat op.
+    self.mock_op_reg_manager.slice_op.assert_not_called()
+
+    # Verify manager groups the new slices.
+    self.mock_op_reg_manager.group_op_slices.assert_called_once_with([
+        self.concat_op_slice, self.relu1_op_slice, self.relu2_op_slice,
+        self.relu3_op_slice, self.batch_norm_op_slice
+    ])
+
+  def testGetConcatOpAxis(self):
+    x = tf.zeros([7, 12, 12, 3])
+    self.assertEqual(
+        concat_op_handler._get_concat_op_axis(tf.concat([x, x], 3).op), 3)
+    self.assertEqual(
+        concat_op_handler._get_concat_op_axis(tf.concat([x, x, x], 1).op), 1)
+    self.assertEqual(
+        concat_op_handler._get_concat_op_axis(tf.concat([x, x, x], 2).op), 2)
+
+
 if __name__ == '__main__':
   tf.test.main()
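Usage note (a hypothetical sanity check mirroring the test's setUp, not part of the commit): for the axis=2 concat built above, the new dispatch condition in ConcatOpHandler.assign_grouping routes the op to GroupingOpHandler rather than slicing along channels:

import tensorflow as tf

x = tf.zeros([2, 4, 4, 6])
op = tf.concat([x, x, x], axis=2).op
axis = op.inputs[-1].op.get_attr('value').int_val[0]  # 2
rank = len(op.inputs[0].shape)  # 4
# axis is neither -1 nor rank - 1, so the grouping branch is taken.
assert axis != -1 and axis != rank - 1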
