@@ -108,6 +108,19 @@ def _bias_add(x, b, data_format):
108
108
109
109
def batch_normalization (x , mean , variance , offset , scale , variance_epsilon , data_format , name = None ):
110
110
"""Data Format aware version of tf.nn.batch_normalization."""
111
+ if data_format == 'channels_last' :
112
+ mean = tf .reshape (mean , [1 ] * (len (x .shape ) - 1 ) + [- 1 ])
113
+ variance = tf .reshape (variance , [1 ] * (len (x .shape ) - 1 ) + [- 1 ])
114
+ offset = tf .reshape (offset , [1 ] * (len (x .shape ) - 1 ) + [- 1 ])
115
+ scale = tf .reshape (scale , [1 ] * (len (x .shape ) - 1 ) + [- 1 ])
116
+ elif data_format == 'channels_first' :
117
+ mean = tf .reshape (mean , [1 ] + [- 1 ] + [1 ] * (len (x .shape ) - 2 ))
118
+ variance = tf .reshape (variance , [1 ] + [- 1 ] + [1 ] * (len (x .shape ) - 2 ))
119
+ offset = tf .reshape (offset , [1 ] + [- 1 ] + [1 ] * (len (x .shape ) - 2 ))
120
+ scale = tf .reshape (scale , [1 ] + [- 1 ] + [1 ] * (len (x .shape ) - 2 ))
121
+ else :
122
+ raise ValueError ('invalid data_format: %s' % data_format )
123
+
111
124
with ops .name_scope (name , 'batchnorm' , [x , mean , variance , scale , offset ]):
112
125
inv = math_ops .rsqrt (variance + variance_epsilon )
113
126
if scale is not None :
@@ -204,13 +217,10 @@ def __init__(
204
217
self .moving_var_init = moving_var_init
205
218
self .num_features = num_features
206
219
220
+ self .channel_axis = - 1 if data_format == 'channels_last' else 1
221
+ self .axes = None
222
+
207
223
if num_features is not None :
208
- if not isinstance (self , BatchNorm1d ) and not isinstance (self , BatchNorm2d ) and not isinstance (self ,
209
- BatchNorm3d ):
210
- raise ValueError (
211
- "Please use BatchNorm1d or BatchNorm2d or BatchNorm3d instead of BatchNorm "
212
- "if you want to specify 'num_features'."
213
- )
214
224
self .build (None )
215
225
self ._built = True
216
226
@@ -233,21 +243,23 @@ def __repr__(self):
233
243
234
244
def _get_param_shape (self , inputs_shape ):
235
245
if self .data_format == 'channels_last' :
236
- axis = len ( inputs_shape ) - 1
246
+ axis = - 1
237
247
elif self .data_format == 'channels_first' :
238
248
axis = 1
239
249
else :
240
250
raise ValueError ('data_format should be either %s or %s' % ('channels_last' , 'channels_first' ))
241
251
242
252
channels = inputs_shape [axis ]
243
- params_shape = [1 ] * len (inputs_shape )
244
- params_shape [axis ] = channels
253
+ params_shape = [channels ]
245
254
246
- axes = [i for i in range (len (inputs_shape )) if i != axis ]
247
- return params_shape , axes
255
+ return params_shape
256
+
257
+ def _check_input_shape (self , inputs ):
258
+ if inputs .ndim <= 1 :
259
+ raise ValueError ('expected input at least 2D, but got {}D input' .format (inputs .ndim ))
248
260
249
261
def build (self , inputs_shape ):
250
- params_shape , self .axes = self ._get_param_shape (inputs_shape )
262
+ params_shape = [ self .num_features ] if self . num_features is not None else self ._get_param_shape (inputs_shape )
251
263
252
264
self .beta , self .gamma = None , None
253
265
if self .beta_init :
@@ -264,7 +276,12 @@ def build(self, inputs_shape):
264
276
)
265
277
266
278
def forward (self , inputs ):
267
- mean , var = tf .nn .moments (inputs , self .axes , keepdims = True )
279
+ self ._check_input_shape (inputs )
280
+
281
+ if self .axes is None :
282
+ self .axes = [i for i in range (len (inputs .shape )) if i != self .channel_axis ]
283
+
284
+ mean , var = tf .nn .moments (inputs , self .axes , keepdims = False )
268
285
if self .is_train :
269
286
# update moving_mean and moving_var
270
287
self .moving_mean = moving_averages .assign_moving_average (
@@ -282,8 +299,8 @@ def forward(self, inputs):
282
299
283
300
284
301
class BatchNorm1d (BatchNorm ):
285
- """The :class:`BatchNorm1d` applies Batch Normalization over 3D input (a mini-batch of 1D
286
- inputs with additional channel dimension), of shape (N, L, C) or (N, C, L).
302
+ """The :class:`BatchNorm1d` applies Batch Normalization over 2D/ 3D input (a mini-batch of 1D
303
+ inputs (optional) with additional channel dimension), of shape (N, C) or (N, L, C) or (N, C, L).
287
304
See more details in :class:`BatchNorm`.
288
305
289
306
Examples
@@ -299,23 +316,9 @@ class BatchNorm1d(BatchNorm):
299
316
300
317
"""
301
318
302
- def _get_param_shape (self , inputs_shape ):
303
- if self .data_format == 'channels_last' :
304
- axis = 2
305
- elif self .data_format == 'channels_first' :
306
- axis = 1
307
- else :
308
- raise ValueError ('data_format should be either %s or %s' % ('channels_last' , 'channels_first' ))
309
-
310
- if self .num_features is None :
311
- channels = inputs_shape [axis ]
312
- else :
313
- channels = self .num_features
314
- params_shape = [1 ] * 3
315
- params_shape [axis ] = channels
316
-
317
- axes = [i for i in range (3 ) if i != axis ]
318
- return params_shape , axes
319
+ def _check_input_shape (self , inputs ):
320
+ if inputs .ndim != 2 and inputs .ndim != 3 :
321
+ raise ValueError ('expected input to be 2D or 3D, but got {}D input' .format (inputs .ndim ))
319
322
320
323
321
324
class BatchNorm2d (BatchNorm ):
@@ -336,23 +339,9 @@ class BatchNorm2d(BatchNorm):
336
339
337
340
"""
338
341
339
- def _get_param_shape (self , inputs_shape ):
340
- if self .data_format == 'channels_last' :
341
- axis = 3
342
- elif self .data_format == 'channels_first' :
343
- axis = 1
344
- else :
345
- raise ValueError ('data_format should be either %s or %s' % ('channels_last' , 'channels_first' ))
346
-
347
- if self .num_features is None :
348
- channels = inputs_shape [axis ]
349
- else :
350
- channels = self .num_features
351
- params_shape = [1 ] * 4
352
- params_shape [axis ] = channels
353
-
354
- axes = [i for i in range (4 ) if i != axis ]
355
- return params_shape , axes
342
+ def _check_input_shape (self , inputs ):
343
+ if inputs .ndim != 4 :
344
+ raise ValueError ('expected input to be 4D, but got {}D input' .format (inputs .ndim ))
356
345
357
346
358
347
class BatchNorm3d (BatchNorm ):
@@ -373,23 +362,9 @@ class BatchNorm3d(BatchNorm):
373
362
374
363
"""
375
364
376
- def _get_param_shape (self , inputs_shape ):
377
- if self .data_format == 'channels_last' :
378
- axis = 4
379
- elif self .data_format == 'channels_first' :
380
- axis = 1
381
- else :
382
- raise ValueError ('data_format should be either %s or %s' % ('channels_last' , 'channels_first' ))
383
-
384
- if self .num_features is None :
385
- channels = inputs_shape [axis ]
386
- else :
387
- channels = self .num_features
388
- params_shape = [1 ] * 5
389
- params_shape [axis ] = channels
390
-
391
- axes = [i for i in range (5 ) if i != axis ]
392
- return params_shape , axes
365
+ def _check_input_shape (self , inputs ):
366
+ if inputs .ndim != 5 :
367
+ raise ValueError ('expected input to be 5D, but got {}D input' .format (inputs .ndim ))
393
368
394
369
395
370
class InstanceNorm (Layer ):
0 commit comments