
Commit e41e44e

Refactor batch norm, BN1d support 2D inputs

1 parent b1cd59d commit e41e44e

1 file changed

tensorlayer/layers/normalization.py

Lines changed: 41 additions & 63 deletions
@@ -108,6 +108,19 @@ def _bias_add(x, b, data_format):
 
 def batch_normalization(x, mean, variance, offset, scale, variance_epsilon, data_format, name=None):
     """Data Format aware version of tf.nn.batch_normalization."""
+    if data_format == 'channels_last':
+        mean = tf.reshape(mean, [1] * (len(x.shape) - 1) + [-1])
+        variance = tf.reshape(variance, [1] * (len(x.shape) - 1) + [-1])
+        offset = tf.reshape(offset, [1] * (len(x.shape) - 1) + [-1])
+        scale = tf.reshape(scale, [1] * (len(x.shape) - 1) + [-1])
+    elif data_format == 'channels_first':
+        mean = tf.reshape(mean, [1] + [-1] + [1] * (len(x.shape) - 2))
+        variance = tf.reshape(variance, [1] + [-1] + [1] * (len(x.shape) - 2))
+        offset = tf.reshape(offset, [1] + [-1] + [1] * (len(x.shape) - 2))
+        scale = tf.reshape(scale, [1] + [-1] + [1] * (len(x.shape) - 2))
+    else:
+        raise ValueError('invalid data_format: %s' % data_format)
+
     with ops.name_scope(name, 'batchnorm', [x, mean, variance, scale, offset]):
         inv = math_ops.rsqrt(variance + variance_epsilon)
         if scale is not None:
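The reshape block added above is what lets flat per-channel statistics broadcast against the input in either layout. A minimal sketch of the channels_first case (the tensor and shapes here are illustrative, not taken from the commit):

import tensorflow as tf

# Per-channel statistics arrive as flat [C] vectors; for an NCHW tensor they
# must become [1, C, 1, 1] before tf.nn.batch_normalization can broadcast them.
x = tf.random.normal([8, 32, 28, 28])                              # N, C, H, W
mean = tf.zeros([32])                                              # shape [C]
mean_b = tf.reshape(mean, [1] + [-1] + [1] * (len(x.shape) - 2))   # -> [1, 32, 1, 1]
print(mean_b.shape)                                                # (1, 32, 1, 1)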
@@ -204,6 +217,9 @@ def __init__(
         self.moving_var_init = moving_var_init
         self.num_features = num_features
 
+        self.channel_axis = -1 if data_format == 'channels_last' else 1
+        self.axes = None
+
         if num_features is not None:
             if not isinstance(self, BatchNorm1d) and not isinstance(self, BatchNorm2d) and not isinstance(self,
                                                                                                           BatchNorm3d):
@@ -233,21 +249,23 @@ def __repr__(self):
 
     def _get_param_shape(self, inputs_shape):
         if self.data_format == 'channels_last':
-            axis = len(inputs_shape) - 1
+            axis = -1
         elif self.data_format == 'channels_first':
             axis = 1
         else:
             raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))
 
         channels = inputs_shape[axis]
-        params_shape = [1] * len(inputs_shape)
-        params_shape[axis] = channels
+        params_shape = [channels]
 
-        axes = [i for i in range(len(inputs_shape)) if i != axis]
-        return params_shape, axes
+        return params_shape
+
+    def _check_input_shape(self, inputs):
+        if inputs.ndim <= 1:
+            raise ValueError('expected input at least 2D, but got {}D input'.format(inputs.ndim))
 
     def build(self, inputs_shape):
-        params_shape, self.axes = self._get_param_shape(inputs_shape)
+        params_shape = [self.num_features] if self.num_features is not None else self._get_param_shape(inputs_shape)
 
         self.beta, self.gamma = None, None
         if self.beta_init:
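In effect, gamma, beta and the moving statistics are now created with a flat [channels] shape, and the reduction axes are no longer computed here. A standalone illustration of the new shape logic (my own copy for demonstration, not the library code):

def param_shape(inputs_shape, data_format='channels_last'):
    # mirrors the refactored _get_param_shape: flat per-channel shape
    axis = -1 if data_format == 'channels_last' else 1
    return [inputs_shape[axis]]

print(param_shape([8, 100, 32]))                      # [32] for (N, L, C)
print(param_shape([8, 32, 100], 'channels_first'))    # [32] for (N, C, L)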
@@ -264,7 +282,12 @@ def build(self, inputs_shape):
         )
 
     def forward(self, inputs):
-        mean, var = tf.nn.moments(inputs, self.axes, keepdims=True)
+        self._check_input_shape(inputs)
+
+        if self.axes is None:
+            self.axes = [i for i in range(len(inputs.shape)) if i != self.channel_axis]
+
+        mean, var = tf.nn.moments(inputs, self.axes, keepdims=False)
         if self.is_train:
             # update moving_mean and moving_var
             self.moving_mean = moving_averages.assign_moving_average(
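Because the moments are now taken with keepdims=False, the broadcast is handled by the reshape added to batch_normalization above, and the reduction axes are derived lazily from the first input seen. A hedged sketch of that computation for a channels_first 4D input (local names, not the layer's attributes):

import tensorflow as tf

x = tf.random.normal([8, 32, 28, 28])                          # N, C, H, W
channel_axis = 1                                               # channels_first
axes = [i for i in range(len(x.shape)) if i != channel_axis]   # [0, 2, 3]
mean, var = tf.nn.moments(x, axes, keepdims=False)
print(mean.shape, var.shape)                                   # (32,) (32,)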
@@ -282,8 +305,8 @@ def forward(self, inputs):
 
 
 class BatchNorm1d(BatchNorm):
-    """The :class:`BatchNorm1d` applies Batch Normalization over 3D input (a mini-batch of 1D
-    inputs with additional channel dimension), of shape (N, L, C) or (N, C, L).
+    """The :class:`BatchNorm1d` applies Batch Normalization over 2D/3D input (a mini-batch of 1D
+    inputs (optional) with additional channel dimension), of shape (N, C) or (N, L, C) or (N, C, L).
     See more details in :class:`BatchNorm`.
 
     Examples
@@ -298,24 +321,9 @@ class BatchNorm1d(BatchNorm):
     >>> bn = tl.layers.BatchNorm1d(num_features=32)
 
     """
-
-    def _get_param_shape(self, inputs_shape):
-        if self.data_format == 'channels_last':
-            axis = 2
-        elif self.data_format == 'channels_first':
-            axis = 1
-        else:
-            raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))
-
-        if self.num_features is None:
-            channels = inputs_shape[axis]
-        else:
-            channels = self.num_features
-        params_shape = [1] * 3
-        params_shape[axis] = channels
-
-        axes = [i for i in range(3) if i != axis]
-        return params_shape, axes
+    def _check_input_shape(self, inputs):
+        if inputs.ndim != 2 and inputs.ndim != 3:
+            raise ValueError('expected input to be 2D or 3D, but got {}D input'.format(inputs.ndim))
 
 
 class BatchNorm2d(BatchNorm):
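The commit title says BatchNorm1d now supports 2D inputs, and the relaxed ndim check above is what allows a plain (N, C) tensor through. A usage sketch in the style of the docstring examples (the exact Input call is an assumption on my part, not taken from this diff):

import tensorlayer as tl

ni = tl.layers.Input([None, 32], name='input')    # 2D input of shape (N, C)
nn = tl.layers.BatchNorm1d(num_features=32)(ni)   # passes the 2D/3D check above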
@@ -335,24 +343,9 @@ class BatchNorm2d(BatchNorm):
     >>> bn = tl.layers.BatchNorm2d(num_features=32)
 
     """
-
-    def _get_param_shape(self, inputs_shape):
-        if self.data_format == 'channels_last':
-            axis = 3
-        elif self.data_format == 'channels_first':
-            axis = 1
-        else:
-            raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))
-
-        if self.num_features is None:
-            channels = inputs_shape[axis]
-        else:
-            channels = self.num_features
-        params_shape = [1] * 4
-        params_shape[axis] = channels
-
-        axes = [i for i in range(4) if i != axis]
-        return params_shape, axes
+    def _check_input_shape(self, inputs):
+        if inputs.ndim != 4:
+            raise ValueError('expected input to be 4D, but got {}D input'.format(inputs.ndim))
 
 
 class BatchNorm3d(BatchNorm):
@@ -372,24 +365,9 @@ class BatchNorm3d(BatchNorm):
     >>> bn = tl.layers.BatchNorm3d(num_features=32)
 
     """
-
-    def _get_param_shape(self, inputs_shape):
-        if self.data_format == 'channels_last':
-            axis = 4
-        elif self.data_format == 'channels_first':
-            axis = 1
-        else:
-            raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))
-
-        if self.num_features is None:
-            channels = inputs_shape[axis]
-        else:
-            channels = self.num_features
-        params_shape = [1] * 5
-        params_shape[axis] = channels
-
-        axes = [i for i in range(5) if i != axis]
-        return params_shape, axes
+    def _check_input_shape(self, inputs):
+        if inputs.ndim != 5:
+            raise ValueError('expected input to be 5D, but got {}D input'.format(inputs.ndim))
 
 
 class InstanceNorm(Layer):
