 def batch_norm_relu(inputs, is_training, data_format):
   """Performs a batch normalization followed by a ReLU."""
   # We set fused=True for a significant performance boost.
+  # See https://www.tensorflow.org/performance/performance_guide#common_fused_ops
   inputs = tf.layers.batch_normalization(
       inputs=inputs, axis=1 if data_format == 'channels_first' else 3,
       momentum=_BATCH_NORM_DECAY, epsilon=_BATCH_NORM_EPSILON, center=True,
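The hunk above cuts off mid-call, so the `fused=True` argument the comment refers to is not visible. A minimal sketch of how the full helper plausibly reads after this change, assuming the TF 1.x `tf.layers.batch_normalization` API; the `scale`, `training`, and `fused` arguments and the two constant values are assumptions, not shown in this diff:

import tensorflow as tf

_BATCH_NORM_DECAY = 0.997    # placeholder; the real constant is defined elsewhere in the file
_BATCH_NORM_EPSILON = 1e-5   # placeholder; the real constant is defined elsewhere in the file

def batch_norm_relu(inputs, is_training, data_format):
  """Performs a batch normalization followed by a ReLU."""
  # fused=True dispatches to a single fused batch-norm kernel instead of
  # several smaller ops, which is where the performance boost comes from.
  inputs = tf.layers.batch_normalization(
      inputs=inputs, axis=1 if data_format == 'channels_first' else 3,
      momentum=_BATCH_NORM_DECAY, epsilon=_BATCH_NORM_EPSILON, center=True,
      scale=True, training=is_training, fused=True)  # trailing args assumed
  return tf.nn.relu(inputs)
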
@@ -240,6 +241,7 @@ def model(inputs, is_training):
     if data_format == 'channels_first':
       # Convert from channels_last (NHWC) to channels_first (NCHW). This
       # provides a large performance boost on GPU.
+      # See https://www.tensorflow.org/performance/performance_guide#data_formats
       inputs = tf.transpose(inputs, [0, 3, 1, 2])
 
     inputs = conv2d_fixed_padding(
@@ -261,14 +263,12 @@ def model(inputs, is_training):
         data_format=data_format)
 
     inputs = batch_norm_relu(inputs, is_training, data_format)
-
     inputs = tf.layers.average_pooling2d(
         inputs=inputs, pool_size=8, strides=1, padding='VALID',
         data_format=data_format)
     inputs = tf.identity(inputs, 'final_avg_pool')
     inputs = tf.reshape(inputs, [-1, 64])
-    inputs = tf.layers.dense(
-        inputs=inputs, units=num_classes)
+    inputs = tf.layers.dense(inputs=inputs, units=num_classes)
     inputs = tf.identity(inputs, 'final_dense')
     return inputs
 
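The `[0, 3, 1, 2]` permutation in the second hunk is what moves the channels axis from last (NHWC) to second (NCHW). A standalone sketch of the shape change, assuming TF 1.x and a hypothetical CIFAR-10-sized batch:

import tensorflow as tf

# Hypothetical NHWC batch: 128 images, 32x32 pixels, 3 channels.
nhwc = tf.zeros([128, 32, 32, 3])
# Output axis i takes input axis perm[i], so [0, 3, 1, 2] yields NCHW.
nchw = tf.transpose(nhwc, [0, 3, 1, 2])
print(nchw.shape)  # (128, 3, 32, 32)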