@@ -108,6 +108,19 @@ def _bias_add(x, b, data_format):

def batch_normalization(x, mean, variance, offset, scale, variance_epsilon, data_format, name=None):
    """Data Format aware version of tf.nn.batch_normalization."""
+    if data_format == 'channels_last':
+        mean = tf.reshape(mean, [1] * (len(x.shape) - 1) + [-1])
+        variance = tf.reshape(variance, [1] * (len(x.shape) - 1) + [-1])
+        offset = tf.reshape(offset, [1] * (len(x.shape) - 1) + [-1])
+        scale = tf.reshape(scale, [1] * (len(x.shape) - 1) + [-1])
+    elif data_format == 'channels_first':
+        mean = tf.reshape(mean, [1] + [-1] + [1] * (len(x.shape) - 2))
+        variance = tf.reshape(variance, [1] + [-1] + [1] * (len(x.shape) - 2))
+        offset = tf.reshape(offset, [1] + [-1] + [1] * (len(x.shape) - 2))
+        scale = tf.reshape(scale, [1] + [-1] + [1] * (len(x.shape) - 2))
+    else:
+        raise ValueError('invalid data_format: %s' % data_format)
+
    with ops.name_scope(name, 'batchnorm', [x, mean, variance, scale, offset]):
        inv = math_ops.rsqrt(variance + variance_epsilon)
        if scale is not None:
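
For reference, a minimal standalone sketch of the broadcast shapes this reshape logic produces; the helper name is hypothetical and not part of the patch, and the -1 slot is the channel dimension that tf.reshape infers:

    # Hypothetical helper mirroring the shape arithmetic above (pure Python, no TensorFlow needed).
    def _param_broadcast_shape(input_rank, data_format):
        if data_format == 'channels_last':
            return [1] * (input_rank - 1) + [-1]        # rank 4 -> [1, 1, 1, -1]
        elif data_format == 'channels_first':
            return [1] + [-1] + [1] * (input_rank - 2)  # rank 4 -> [1, -1, 1, 1]
        raise ValueError('invalid data_format: %s' % data_format)

    print(_param_broadcast_shape(4, 'channels_last'))   # [1, 1, 1, -1]
    print(_param_broadcast_shape(4, 'channels_first'))  # [1, -1, 1, 1]
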
@@ -204,6 +217,9 @@ def __init__(
        self.moving_var_init = moving_var_init
        self.num_features = num_features

+        self.channel_axis = -1 if data_format == 'channels_last' else 1
+        self.axes = None
+
        if num_features is not None:
            if not isinstance(self, BatchNorm1d) and not isinstance(self, BatchNorm2d) and not isinstance(self,
                                                                                                           BatchNorm3d):
@@ -233,21 +249,23 @@ def __repr__(self):

    def _get_param_shape(self, inputs_shape):
        if self.data_format == 'channels_last':
-            axis = len(inputs_shape) - 1
+            axis = -1
        elif self.data_format == 'channels_first':
            axis = 1
        else:
            raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))

        channels = inputs_shape[axis]
-        params_shape = [1] * len(inputs_shape)
-        params_shape[axis] = channels
+        params_shape = [channels]

-        axes = [i for i in range(len(inputs_shape)) if i != axis]
-        return params_shape, axes
+        return params_shape
+
+    def _check_input_shape(self, inputs):
+        if inputs.ndim <= 1:
+            raise ValueError('expected input at least 2D, but got {}D input'.format(inputs.ndim))

    def build(self, inputs_shape):
-        params_shape, self.axes = self._get_param_shape(inputs_shape)
+        params_shape = [self.num_features] if self.num_features is not None else self._get_param_shape(inputs_shape)

        self.beta, self.gamma = None, None
        if self.beta_init:
@@ -264,7 +282,12 @@ def build(self, inputs_shape):
            )

    def forward(self, inputs):
-        mean, var = tf.nn.moments(inputs, self.axes, keepdims=True)
+        self._check_input_shape(inputs)
+
+        if self.axes is None:
+            self.axes = [i for i in range(len(inputs.shape)) if i != self.channel_axis]
+
+        mean, var = tf.nn.moments(inputs, self.axes, keepdims=False)
        if self.is_train:
            # update moving_mean and moving_var
            self.moving_mean = moving_averages.assign_moving_average(
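
As a rough illustration of the reworked forward path (plain TensorFlow 2.x with illustrative shapes and epsilon, not the module's own code): per-channel statistics are now reduced with keepdims=False over every non-channel axis, and the data_format-aware batch_normalization above is what broadcasts the resulting 1-D tensors back:

    import tensorflow as tf

    x = tf.random.normal([8, 32, 32, 16])               # NHWC input, data_format='channels_last'
    axes = [0, 1, 2]                                     # every axis except the channel axis
    mean, var = tf.nn.moments(x, axes, keepdims=False)   # both have shape (16,)

    # Reshape the 1-D statistics to a broadcastable shape, mirroring what the
    # patched batch_normalization() does for 'channels_last'.
    bshape = [1] * (len(x.shape) - 1) + [-1]             # [1, 1, 1, -1]
    y = (x - tf.reshape(mean, bshape)) * tf.math.rsqrt(tf.reshape(var, bshape) + 1e-5)
    print(y.shape)                                       # (8, 32, 32, 16)
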
@@ -282,8 +305,8 @@ def forward(self, inputs):


class BatchNorm1d(BatchNorm):
-    """The :class:`BatchNorm1d` applies Batch Normalization over 3D input (a mini-batch of 1D
-    inputs with additional channel dimension), of shape (N, L, C) or (N, C, L).
+    """The :class:`BatchNorm1d` applies Batch Normalization over 2D/3D input (a mini-batch of 1D
+    inputs, optionally with an additional channel dimension), of shape (N, C), (N, L, C) or (N, C, L).
    See more details in :class:`BatchNorm`.

    Examples
@@ -298,24 +321,9 @@ class BatchNorm1d(BatchNorm):
    >>> bn = tl.layers.BatchNorm1d(num_features=32)

    """
-
-    def _get_param_shape(self, inputs_shape):
-        if self.data_format == 'channels_last':
-            axis = 2
-        elif self.data_format == 'channels_first':
-            axis = 1
-        else:
-            raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))
-
-        if self.num_features is None:
-            channels = inputs_shape[axis]
-        else:
-            channels = self.num_features
-        params_shape = [1] * 3
-        params_shape[axis] = channels
-
-        axes = [i for i in range(3) if i != axis]
-        return params_shape, axes
+    def _check_input_shape(self, inputs):
+        if inputs.ndim != 2 and inputs.ndim != 3:
+            raise ValueError('expected input to be 2D or 3D, but got {}D input'.format(inputs.ndim))


class BatchNorm2d(BatchNorm):
@@ -335,24 +343,9 @@ class BatchNorm2d(BatchNorm):
    >>> bn = tl.layers.BatchNorm2d(num_features=32)

    """
-
-    def _get_param_shape(self, inputs_shape):
-        if self.data_format == 'channels_last':
-            axis = 3
-        elif self.data_format == 'channels_first':
-            axis = 1
-        else:
-            raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))
-
-        if self.num_features is None:
-            channels = inputs_shape[axis]
-        else:
-            channels = self.num_features
-        params_shape = [1] * 4
-        params_shape[axis] = channels
-
-        axes = [i for i in range(4) if i != axis]
-        return params_shape, axes
+    def _check_input_shape(self, inputs):
+        if inputs.ndim != 4:
+            raise ValueError('expected input to be 4D, but got {}D input'.format(inputs.ndim))


class BatchNorm3d(BatchNorm):
@@ -372,24 +365,9 @@ class BatchNorm3d(BatchNorm):
    >>> bn = tl.layers.BatchNorm3d(num_features=32)

    """
-
-    def _get_param_shape(self, inputs_shape):
-        if self.data_format == 'channels_last':
-            axis = 4
-        elif self.data_format == 'channels_first':
-            axis = 1
-        else:
-            raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))
-
-        if self.num_features is None:
-            channels = inputs_shape[axis]
-        else:
-            channels = self.num_features
-        params_shape = [1] * 5
-        params_shape[axis] = channels
-
-        axes = [i for i in range(5) if i != axis]
-        return params_shape, axes
+    def _check_input_shape(self, inputs):
+        if inputs.ndim != 5:
+            raise ValueError('expected input to be 5D, but got {}D input'.format(inputs.ndim))


class InstanceNorm(Layer):