Commit b9cc6d8

eladebanmn-robot authored and committed

Adds low-level support for Conv3D.

Note that 3D convolutions are not yet supported by network_regularizer.

PiperOrigin-RevId: 248762137

1 parent 537bb9d commit b9cc6d8

4 files changed: +101 −63 lines changed
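For orientation, here is a minimal usage sketch of what this commit enables. It is not part of the commit itself, and assumes the TF1-style tf.contrib.layers API used by this repo's tests:

import tensorflow as tf

from morph_net.network_regularizers import cost_calculator
from morph_net.network_regularizers import resource_function

layers = tf.contrib.layers

# A tiny NDHWC video batch: [batch, depth, height, width, channels].
video = tf.zeros([1, 15, 12, 13, 17])
layers.conv3d(video, 10, [7, 5, 3], stride=2, padding='SAME', scope='conv')
conv3d_op = tf.get_default_graph().get_operation_by_name('conv/Conv3D')

# Newly supported: the FLOP coefficient of a Conv3D op (multiply by input
# depth and output depth to get the total FLOP count).
print(resource_function.flop_coeff(conv3d_op))

# Renamed from _get_input and made public: the activation input of an op.
print(cost_calculator.get_input_activation(conv3d_op))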

morph_net/network_regularizers/cost_calculator.py

Lines changed: 10 additions & 6 deletions

@@ -8,7 +8,9 @@
 CONV2D_OPS = ('Conv2D', 'Conv2DBackpropInput', 'DepthwiseConv2dNative')
-FLOP_OPS = CONV2D_OPS + ('MatMul',)
+CONV3D_OPS = ('Conv3D',)
+CONV_OPS = CONV2D_OPS + CONV3D_OPS
+FLOP_OPS = CONV_OPS + ('MatMul',)
 SUPPORTED_OPS = FLOP_OPS + (
     'Add', 'AddN', 'ConcatV2', 'FusedBatchNorm', 'Mul', 'Relu', 'Relu6', 'Sum')

@@ -60,7 +62,7 @@ def _get_cost_or_regularization_term(self, is_regularization, ops=None):
        continue

      # Get regularization and alive terms for input and output.
-      input_tensor = _get_input(op)
+      input_tensor = get_input_activation(op)
      if op.type == 'ConcatV2':
        # For concat, the input and output regularization are identical but the
        # input is composed of multiple concatenated regularizers. Thus, just

@@ -110,8 +112,8 @@ def get_regularization_term(self, ops=None):
     return self._get_cost_or_regularization_term(True, ops)


-def _get_input(op):
-  """Returns the input to that op that represents the activations.
+def get_input_activation(op):
+  """Returns the input to `op` that represents the activations.

   (as opposed to e.g. weights.)

@@ -122,10 +124,12 @@ def _get_input(op):
     A tf.Tensor representing the input activations.

   Raises:
+    ValueError: op type not supported.
     ValueError: MatMul is used with transposition (unsupported).
   """
-  assert op.type in SUPPORTED_OPS, 'Op type %s is not supported.' % op.type
-  if op.type == 'Conv2D' or op.type == 'DepthwiseConv2dNative':
+  if op.type not in SUPPORTED_OPS:
+    raise ValueError('Op type %s is not supported.' % op.type)
+  if op.type in ('Conv3D', 'Conv2D', 'DepthwiseConv2dNative'):
     return op.inputs[0]
   if op.type == 'Conv2DBackpropInput':
     return op.inputs[2]
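The `inputs[2]` case exists because a Conv2DBackpropInput op (the op behind conv2d_transpose) takes its inputs in the order (output sizes, filter, activations), so the activation tensor is the third input. A small sketch to inspect this, with illustrative shapes:

import tensorflow as tf

layers = tf.contrib.layers

image = tf.zeros([1, 11, 13, 17])
layers.conv2d_transpose(image, 29, [7, 5], stride=2, padding='SAME', scope='convt2')

op = tf.get_default_graph().get_operation_by_name('convt2/conv2d_transpose')
for i, tensor in enumerate(op.inputs):
  # Expected order: [0] target output sizes, [1] filter, [2] activations.
  print(i, tensor.name, tensor.shape)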

morph_net/network_regularizers/cost_calculator_test.py

Lines changed: 30 additions & 4 deletions

@@ -5,12 +5,13 @@
 from __future__ import print_function

 import collections
+from absl.testing import parameterized
 from morph_net.framework import batch_norm_source_op_handler
 from morph_net.framework import concat_op_handler
 from morph_net.framework import grouping_op_handler
 from morph_net.framework import op_regularizer_manager as orm
 from morph_net.framework import output_non_passthrough_op_handler
-from morph_net.network_regularizers import cost_calculator
+from morph_net.network_regularizers import cost_calculator as cc
 from morph_net.network_regularizers import resource_function
 from morph_net.testing import add_concat_model_stub
 import tensorflow as tf

@@ -19,7 +20,7 @@
 layers = tf.contrib.layers


-class NetworkRegularizerTest(tf.test.TestCase):
+class CostCalculatorTest(parameterized.TestCase, tf.test.TestCase):

   def _batch_norm_scope(self):
     params = {

@@ -70,8 +71,7 @@ def testImageIsNotZerothOutputOfOp(self):

     # Create OpRegularizerManager and NetworkRegularizer for test.
     manager = orm.OpRegularizerManager([output_op], op_handler_dict)
-    calculator = cost_calculator.CostCalculator(
-        manager, resource_function.flop_function)
+    calculator = cc.CostCalculator(manager, resource_function.flop_function)

     # Calculate expected FLOP cost.
     expected_alive_conv1 = sum(add_concat_model_stub.expected_alive()['conv1'])

@@ -92,6 +92,32 @@ def testImageIsNotZerothOutputOfOp(self):
      queue.enqueue((non_image_tensor, image)).run()
      self.assertEqual(expected_cost,
                       calculator.get_cost([conv1_op]).eval())
+      # For 0/1 assignments, cost and reg_term are equal:
+      self.assertEqual(expected_cost,
+                       calculator.get_regularization_term([conv1_op]).eval())
+
+  @parameterized.named_parameters(
+      ('_conv2d', 4, lambda x: layers.conv2d(x, 16, 3), 'Conv2D'),
+      ('_convt', 4, lambda x: layers.conv2d_transpose(x, 16, 3),
+       'conv2d_transpose'),
+      ('_conv2s', 4, lambda x: layers.separable_conv2d(x, None, 3),
+       'depthwise'),
+      ('_conv3d', 5, lambda x: layers.conv3d(x, 16, 3), 'Conv3D'))
+  def test_get_input_activation2(self, rank, fn, op_name):
+    g = tf.get_default_graph()
+    inputs = tf.zeros([6] * rank)
+    with arg_scope([
+        layers.conv2d, layers.conv2d_transpose, layers.separable_conv2d,
+        layers.conv3d
+    ],
+                   scope='test_layer'):
+      _ = fn(inputs)
+    for op in g.get_operations():
+      print(op.name)
+    self.assertEqual(
+        inputs,
+        cc.get_input_activation(
+            g.get_operation_by_name('test_layer/' + op_name)))


 if __name__ == '__main__':

morph_net/network_regularizers/resource_function.py

Lines changed: 33 additions & 43 deletions

@@ -5,6 +5,8 @@
 from __future__ import print_function
 from morph_net.framework import op_handler_util
 from morph_net.network_regularizers import cost_calculator
+
+import numpy as np
 import tensorflow as tf

 # Data sheet for K80:
@@ -54,15 +56,17 @@ def flop_coeff(op):
   have one multiplication and one addition for each convolution weight and
   pixel. This function returns C.

+  Supported operation names are listed in cost_calculator.FLOP_OPS.
+
   Args:
-    op: A tf.Operation of type 'Conv2D' or 'MatMul'.
+    op: A tf.Operation of a supported type.

   Returns:
     A float, the coefficient that when multiplied by the input depth and by the
     output depth gives the number of flops needed to compute the convolution.

   Raises:
-    ValueError: conv_op is not a tf.Operation of type Conv2D.
+    ValueError: op is not a supported tf.Operation.
   """
   if not is_flop_op(op):
     return 0.0
@@ -72,24 +76,35 @@ def flop_coeff(op):
     return 2.0
   # Looking at the output shape makes it easy to automatically take into
   # account strides and the type of padding.
-  if op.type == 'Conv2D' or op.type == 'DepthwiseConv2dNative':
-    shape = op.outputs[0].shape.dims
-    tensor_shape = tf.shape(op.outputs[0])
+  def kernel_num_elements(tensor):
+    """Returns the number of spatial elements of `tensor`.
+
+    Args:
+      tensor: The output (or, for transposed convolution, input) tensor.
+
+    Returns:
+      The number of spatial elements: a Python number if the shape is fully
+      known, otherwise a scalar float tf.Tensor.
+    """
+    num_elements = np.prod(tensor.shape.dims[1:-1]).value
+    if num_elements:
+      return num_elements
+    return tf.to_float(tf.reduce_prod(tf.shape(tensor)[1:-1]))
+
+  if op.type in ('Conv2D', 'DepthwiseConv2dNative', 'Conv3D'):
+    num_elements = kernel_num_elements(op.outputs[0])
   elif op.type == 'Conv2DBackpropInput':
     # For a transposed convolution, the input and the output are swapped (as
     # far as shapes are concerned). In other words, for a given filter shape
     # and stride, if Conv2D maps from shapeX to shapeY, Conv2DBackpropInput
     # maps from shapeY to shapeX. Therefore wherever we use the output shape
     # for Conv2D, we use the input shape for Conv2DBackpropInput.
-    input_tensor = _get_input(op)
-    shape = input_tensor.shape.dims
-    tensor_shape = tf.shape(input_tensor)
-
+    num_elements = kernel_num_elements(cost_calculator.get_input_activation(op))
+  else:
+    # Can only happen if elements are added to FLOP_OPS and not taken care of.
+    assert False, '%s in cost_calculator.FLOP_OPS but not handled' % op.type
   # Handle dynamic shaping while keeping the old code path, to not break
   # other clients.
-  size = shape[1] * shape[2]
-  size = size.value or tf.to_float(tensor_shape[1] * tensor_shape[2])
-  return 2.0 * size * _get_conv_filter_size(op)
+  return 2.0 * num_elements * _get_conv_filter_size(op)


 def num_weights_coeff(op):
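The static/dynamic split in kernel_num_elements works because, under TF1 shape semantics, multiplying Dimension objects that include an unknown dimension yields a Dimension whose .value is None, which is falsy. A sketch of both code paths, with illustrative shapes:

import numpy as np
import tensorflow as tf

# Fully static shape: np.prod over the Dimension objects gives a known value.
static = tf.zeros([1, 8, 6, 7, 10])
print(np.prod(static.shape.dims[1:-1]).value)  # 8 * 6 * 7 = 336

# Partially unknown shape: the product Dimension has value None (falsy), so
# flop_coeff falls back to a graph-time tf.reduce_prod over tf.shape.
dynamic = tf.placeholder(tf.float32, [1, None, None, None, 10])
print(np.prod(dynamic.shape.dims[1:-1]).value)  # None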
@@ -107,7 +122,7 @@ def num_weights_coeff(op):
   """
   if not is_flop_op(op):
     return 0.0
-  return (_get_conv_filter_size(op) if op.type in cost_calculator.CONV2D_OPS
+  return (_get_conv_filter_size(op) if op.type in cost_calculator.CONV_OPS
           else 1.0)
@@ -420,37 +435,12 @@ def is_flop_op(op):


 def _get_conv_filter_size(conv_op):
-  assert conv_op.type in cost_calculator.CONV2D_OPS
+  # Works for both 2D and 3D convolutions, whose weight tensors are 4D or 5D
+  # respectively: [kernel_size[:], input_depth, output_depth].
+  assert conv_op.type in cost_calculator.CONV_OPS
   conv_weights = conv_op.inputs[1]
-  filter_shape = conv_weights.shape.as_list()[:2]
-  return filter_shape[0] * filter_shape[1]
-
-
-def _get_input(op):
-  """Returns the input to that op that represents the activations.
-
-  Specifically, return the activation tensor rather than the weight tensor.
-
-  Args:
-    op: A tf.Operation object with type in _SUPPORTED_OPS.
-
-  Returns:
-    A tf.Tensor representing the input activations.
-
-  Raises:
-    ValueError: MatMul is used with transposition (unsupported).
-  """
-  assert op.type in cost_calculator.SUPPORTED_OPS, (
-      'Op type %s is not supported.' % op.type)
-  if op.type == 'Conv2D' or op.type == 'DepthwiseConv2dNative':
-    return op.inputs[0]
-  if op.type == 'Conv2DBackpropInput':
-    return op.inputs[2]
-  if op.type == 'MatMul':
-    if op.get_attr('transpose_a') or op.get_attr('transpose_b'):
-      raise ValueError('MatMul with transposition is not yet supported.')
-    return op.inputs[0]
-  return op.inputs[0]
+  filter_shape = conv_weights.shape.as_list()[:-2]
+  return np.prod(filter_shape)


 def _calculate_bilinear_regularization(
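The [:-2] slice is what makes _get_conv_filter_size rank-agnostic: dropping the trailing input/output depth dimensions leaves exactly the spatial kernel dimensions, whether there are two or three of them. A quick illustration with hypothetical shapes:

import numpy as np

# Conv2D weights: [kernel_h, kernel_w, input_depth, output_depth].
print(np.prod([7, 5, 17, 19][:-2]))     # 35 spatial taps per filter.

# Conv3D weights: [kernel_d, kernel_h, kernel_w, input_depth, output_depth].
print(np.prod([7, 5, 3, 17, 19][:-2]))  # 105 spatial taps per filter.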

morph_net/network_regularizers/resource_function_test.py

Lines changed: 28 additions & 10 deletions

@@ -34,13 +34,18 @@ def setUp(self):
     layers.separable_conv2d(
         net, None, [3, 2], depth_multiplier=1, padding='SAME', scope='dw1')

-    self.conv_op = tf.get_default_graph().get_operation_by_name('conv1/Conv2D')
-    self.convt_op = tf.get_default_graph().get_operation_by_name(
+    self.video_shape = (1, 11, 9, 13, 17)
+    self.video = tf.placeholder(tf.float32, shape=[1, None, None, None, 17])
+    net = layers.conv3d(
+        self.video, 19, [7, 3, 5], stride=2, padding='SAME', scope='vconv1')
+    g = tf.get_default_graph()
+    self.conv_op = g.get_operation_by_name('conv1/Conv2D')
+    self.convt_op = g.get_operation_by_name(
         'convt2/conv2d_transpose')
-    self.matmul_op = tf.get_default_graph().get_operation_by_name(
-        'FC/MatMul')
-    self.dw_op = tf.get_default_graph().get_operation_by_name(
-        'dw1/depthwise')
+    self.matmul_op = g.get_operation_by_name('FC/MatMul')
+    self.dw_op = g.get_operation_by_name('dw1/depthwise')
+    self.conv3d_op = g.get_operation_by_name(
+        'vconv1/Conv3D')

   @parameterized.named_parameters(
       ('_BatchSize1_AliveIn17_AliveOut19', 1, 17, 19),

@@ -1149,7 +1154,7 @@ def testBadHardware(self):
     _ = resource_function.latency_function_factory(None, 11)

   def testConvFlopsCoeff(self):
-    tf.reset_default_graph()
+    tf.compat.v1.reset_default_graph()
     image = tf.constant(0.0, shape=[1, 11, 13, 17])
     layers.conv2d(image, 19, [7, 5], stride=2, padding='SAME', scope='conv1')
     conv_op = tf.get_default_graph().get_operation_by_name('conv1/Conv2D')

@@ -1159,7 +1164,7 @@ def testConvFlopsCoeff(self):
     self.assertNearRelatively(expected_coeff, actual_coeff)

   def testConvFlopsCoeffUnknownShape(self):
-    tf.reset_default_graph()
+    tf.compat.v1.reset_default_graph()
     image = tf.placeholder(tf.float32, shape=[1, None, None, 17])
     net = layers.conv2d(
         image, 19, [7, 5], stride=2, padding='SAME', scope='conv1')

@@ -1176,7 +1181,7 @@ def testConvFlopsCoeffUnknownShape(self):
     self.assertNearRelatively(expected_coeff, actual_coeff)

   def testConvTransposeFlopsCoeff(self):
-    tf.reset_default_graph()
+    tf.compat.v1.reset_default_graph()
     image = tf.constant(0.0, shape=[1, 11, 13, 17])
     layers.conv2d_transpose(
         image, 29, [7, 5], stride=2, padding='SAME', scope='convt2')

@@ -1204,7 +1209,7 @@ def testFcNumWeightsCoeff(self):
     self.assertNearRelatively(1.0, actual_coeff)

   def testDepthwiseConvFlopsCoeff(self):
-    tf.reset_default_graph()
+    tf.compat.v1.reset_default_graph()
     image = tf.constant(0.0, shape=[1, 11, 13, 17])
     net = layers.conv2d(
         image, 10, [7, 5], stride=2, padding='SAME', scope='conv2')

@@ -1218,6 +1223,19 @@ def testDepthwiseConvFlopsCoeff(self):
     actual_coeff = resource_function.flop_coeff(dw_op)
     self.assertNearRelatively(expected_coeff, actual_coeff)

+  def test_conv3d_flops_coeff(self):
+    tf.compat.v1.reset_default_graph()
+    input_depth = 17
+    output_depth = 10
+    video = tf.zeros([1, 15, 12, 13, input_depth])
+    _ = layers.conv3d(
+        video, output_depth, [7, 5, 3], stride=2, padding='SAME', scope='conv')
+    conv_op = tf.get_default_graph().get_operation_by_name('conv/Conv3D')
+    # Divide by the input depth and the output depth to get the coefficient.
+    expected_coeff = _flops(conv_op) / (input_depth * output_depth)
+    actual_coeff = resource_function.flop_coeff(conv_op)
+    self.assertNearRelatively(expected_coeff, actual_coeff)
+

 def _flops(op):
   """Get the number of flops of a convolution, from the ops stats registry.
