Skip to content

Commit 7f74569

Browse files
authored
Merge pull request #1040 from tensorlayer/ChrisWu1997-patch-1
Refactor batch normalization
2 parents 7ff88b3 + 3aba313 commit 7f74569

File tree

6 files changed

+126
-73
lines changed

6 files changed

+126
-73
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,14 @@ To release a new version, please update the changelog as follows:
8080

8181
### Fixed
8282
- RNN updates: remove warnings, fix the seq_len=0 case, add unit tests (#PR 1033)
83+
- BN updates: fix BatchNorm1d for 2D data, refactored (#PR 1040)
8384

8485
### Removed
8586

8687
### Security
8788

8889
### Contributors
90+
- @ChrisWu1997: #1040
8991

9092

9193
## [2.2.1]

tensorlayer/files/utils.py

+4
Original file line numberDiff line numberDiff line change
@@ -2666,6 +2666,10 @@ def _load_weights_from_hdf5_group(f, layers, skip=False):
26662666
elif isinstance(layer, tl.layers.Layer):
26672667
weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
26682668
for iid, w_name in enumerate(weight_names):
2669+
# FIXME: compatibility shim only - BatchNorm weights stored with more than one dimension are squeezed to 1-D before being assigned
2670+
if isinstance(layer, tl.layers.BatchNorm) and np.asarray(g[w_name]).ndim > 1:
2671+
assign_tf_variable(layer.all_weights[iid], np.asarray(g[w_name]).squeeze())
2672+
continue
26692673
assign_tf_variable(layer.all_weights[iid], np.asarray(g[w_name]))
26702674
else:
26712675
raise Exception("Only layer or model can be saved into hdf5.")

tensorlayer/layers/normalization.py

+41-66
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,19 @@ def _bias_add(x, b, data_format):
108108

109109
def batch_normalization(x, mean, variance, offset, scale, variance_epsilon, data_format, name=None):
110110
"""Data Format aware version of tf.nn.batch_normalization."""
111+
if data_format == 'channels_last':
112+
mean = tf.reshape(mean, [1] * (len(x.shape) - 1) + [-1])
113+
variance = tf.reshape(variance, [1] * (len(x.shape) - 1) + [-1])
114+
offset = tf.reshape(offset, [1] * (len(x.shape) - 1) + [-1])
115+
scale = tf.reshape(scale, [1] * (len(x.shape) - 1) + [-1])
116+
elif data_format == 'channels_first':
117+
mean = tf.reshape(mean, [1] + [-1] + [1] * (len(x.shape) - 2))
118+
variance = tf.reshape(variance, [1] + [-1] + [1] * (len(x.shape) - 2))
119+
offset = tf.reshape(offset, [1] + [-1] + [1] * (len(x.shape) - 2))
120+
scale = tf.reshape(scale, [1] + [-1] + [1] * (len(x.shape) - 2))
121+
else:
122+
raise ValueError('invalid data_format: %s' % data_format)
123+
111124
with ops.name_scope(name, 'batchnorm', [x, mean, variance, scale, offset]):
112125
inv = math_ops.rsqrt(variance + variance_epsilon)
113126
if scale is not None:
@@ -204,13 +217,10 @@ def __init__(
204217
self.moving_var_init = moving_var_init
205218
self.num_features = num_features
206219

220+
self.channel_axis = -1 if data_format == 'channels_last' else 1
221+
self.axes = None
222+
207223
if num_features is not None:
208-
if not isinstance(self, BatchNorm1d) and not isinstance(self, BatchNorm2d) and not isinstance(self,
209-
BatchNorm3d):
210-
raise ValueError(
211-
"Please use BatchNorm1d or BatchNorm2d or BatchNorm3d instead of BatchNorm "
212-
"if you want to specify 'num_features'."
213-
)
214224
self.build(None)
215225
self._built = True
216226

@@ -233,21 +243,23 @@ def __repr__(self):
233243

234244
def _get_param_shape(self, inputs_shape):
235245
if self.data_format == 'channels_last':
236-
axis = len(inputs_shape) - 1
246+
axis = -1
237247
elif self.data_format == 'channels_first':
238248
axis = 1
239249
else:
240250
raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))
241251

242252
channels = inputs_shape[axis]
243-
params_shape = [1] * len(inputs_shape)
244-
params_shape[axis] = channels
253+
params_shape = [channels]
245254

246-
axes = [i for i in range(len(inputs_shape)) if i != axis]
247-
return params_shape, axes
255+
return params_shape
256+
257+
def _check_input_shape(self, inputs):
258+
if inputs.ndim <= 1:
259+
raise ValueError('expected input at least 2D, but got {}D input'.format(inputs.ndim))
248260

249261
def build(self, inputs_shape):
250-
params_shape, self.axes = self._get_param_shape(inputs_shape)
262+
params_shape = [self.num_features] if self.num_features is not None else self._get_param_shape(inputs_shape)
251263

252264
self.beta, self.gamma = None, None
253265
if self.beta_init:
@@ -264,7 +276,12 @@ def build(self, inputs_shape):
264276
)
265277

266278
def forward(self, inputs):
267-
mean, var = tf.nn.moments(inputs, self.axes, keepdims=True)
279+
self._check_input_shape(inputs)
280+
281+
if self.axes is None:
282+
self.axes = [i for i in range(len(inputs.shape)) if i != self.channel_axis]
283+
284+
mean, var = tf.nn.moments(inputs, self.axes, keepdims=False)
268285
if self.is_train:
269286
# update moving_mean and moving_var
270287
self.moving_mean = moving_averages.assign_moving_average(
@@ -282,8 +299,8 @@ def forward(self, inputs):
282299

283300

284301
class BatchNorm1d(BatchNorm):
285-
"""The :class:`BatchNorm1d` applies Batch Normalization over 3D input (a mini-batch of 1D
286-
inputs with additional channel dimension), of shape (N, L, C) or (N, C, L).
302+
"""The :class:`BatchNorm1d` applies Batch Normalization over 2D/3D input (a mini-batch of 1D
303+
inputs whose length dimension is optional, plus a channel dimension), of shape (N, C) or (N, L, C) or (N, C, L).
287304
See more details in :class:`BatchNorm`.
288305
289306
Examples
@@ -299,23 +316,9 @@ class BatchNorm1d(BatchNorm):
299316
300317
"""
301318

302-
def _get_param_shape(self, inputs_shape):
303-
if self.data_format == 'channels_last':
304-
axis = 2
305-
elif self.data_format == 'channels_first':
306-
axis = 1
307-
else:
308-
raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))
309-
310-
if self.num_features is None:
311-
channels = inputs_shape[axis]
312-
else:
313-
channels = self.num_features
314-
params_shape = [1] * 3
315-
params_shape[axis] = channels
316-
317-
axes = [i for i in range(3) if i != axis]
318-
return params_shape, axes
319+
def _check_input_shape(self, inputs):
320+
if inputs.ndim != 2 and inputs.ndim != 3:
321+
raise ValueError('expected input to be 2D or 3D, but got {}D input'.format(inputs.ndim))
319322

320323

321324
class BatchNorm2d(BatchNorm):
@@ -336,23 +339,9 @@ class BatchNorm2d(BatchNorm):
336339
337340
"""
338341

339-
def _get_param_shape(self, inputs_shape):
340-
if self.data_format == 'channels_last':
341-
axis = 3
342-
elif self.data_format == 'channels_first':
343-
axis = 1
344-
else:
345-
raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))
346-
347-
if self.num_features is None:
348-
channels = inputs_shape[axis]
349-
else:
350-
channels = self.num_features
351-
params_shape = [1] * 4
352-
params_shape[axis] = channels
353-
354-
axes = [i for i in range(4) if i != axis]
355-
return params_shape, axes
342+
def _check_input_shape(self, inputs):
343+
if inputs.ndim != 4:
344+
raise ValueError('expected input to be 4D, but got {}D input'.format(inputs.ndim))
356345

357346

358347
class BatchNorm3d(BatchNorm):
@@ -373,23 +362,9 @@ class BatchNorm3d(BatchNorm):
373362
374363
"""
375364

376-
def _get_param_shape(self, inputs_shape):
377-
if self.data_format == 'channels_last':
378-
axis = 4
379-
elif self.data_format == 'channels_first':
380-
axis = 1
381-
else:
382-
raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))
383-
384-
if self.num_features is None:
385-
channels = inputs_shape[axis]
386-
else:
387-
channels = self.num_features
388-
params_shape = [1] * 5
389-
params_shape[axis] = channels
390-
391-
axes = [i for i in range(5) if i != axis]
392-
return params_shape, axes
365+
def _check_input_shape(self, inputs):
366+
if inputs.ndim != 5:
367+
raise ValueError('expected input to be 5D, but got {}D input'.format(inputs.ndim))
393368

394369

395370
class InstanceNorm(Layer):

tensorlayer/models/mobilenetv1.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ def restore_params(network, path='models'):
4343
expected_bytes=25600116
4444
) # ls -al
4545
params = load_npz(name=os.path.join(path, 'mobilenet.npz'))
46-
for idx, net_weight in enumerate(network.all_weights):
47-
if 'batchnorm' in net_weight.name:
48-
params[idx] = params[idx].reshape(1, 1, 1, -1)
46+
# for idx, net_weight in enumerate(network.all_weights):
47+
# if 'batchnorm' in net_weight.name:
48+
# params[idx] = params[idx].reshape(1, 1, 1, -1)
4949
assign_weights(params[:len(network.all_weights)], network)
5050
del params
5151

tensorlayer/models/resnet.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,8 @@ def restore_params(network, path='models'):
194194
continue
195195
w_names = list(f[layer.name])
196196
params = [f[layer.name][n][:] for n in w_names]
197-
if 'bn' in layer.name:
198-
params = [x.reshape(1, 1, 1, -1) for x in params]
197+
# if 'bn' in layer.name:
198+
# params = [x.reshape(1, 1, 1, -1) for x in params]
199199
assign_weights(params, layer)
200200
del params
201201

tests/layers/test_layers_normalization.py

+74-2
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@ class Laye_BatchNorm_Test(CustomTestCase):
1818
@classmethod
1919
def setUpClass(cls):
2020

21+
x_0_input_shape = [None, 10]
2122
x_1_input_shape = [None, 100, 1]
2223
x_2_input_shape = [None, 100, 100, 3]
2324
x_3_input_shape = [None, 100, 100, 100, 3]
2425
batchsize = 2
2526

27+
cls.x0 = tf.random.normal([batchsize] + x_0_input_shape[1:])
2628
cls.x1 = tf.random.normal([batchsize] + x_1_input_shape[1:])
2729
cls.x2 = tf.random.normal([batchsize] + x_2_input_shape[1:])
2830
cls.x3 = tf.random.normal([batchsize] + x_3_input_shape[1:])
@@ -36,16 +38,58 @@ def setUpClass(cls):
3638

3739
ni_2 = Input(x_2_input_shape, name='test_ni2')
3840
nn_2 = Conv2d(n_filter=32, filter_size=(3, 3), strides=(2, 2), name='test_conv2d')(ni_2)
39-
n2_b = BatchNorm2d(name='test_bn2d')(nn_2)
41+
n2_b = BatchNorm(name='test_bn2d')(nn_2)
4042
cls.n2_b = n2_b
4143
cls.base_2d = Model(inputs=ni_2, outputs=n2_b, name='test_base_2d')
4244

4345
ni_3 = Input(x_3_input_shape, name='test_ni2')
4446
nn_3 = Conv3d(n_filter=32, filter_size=(3, 3, 3), strides=(2, 2, 2), name='test_conv3d')(ni_3)
45-
n3_b = BatchNorm3d(name='test_bn3d')(nn_3)
47+
n3_b = BatchNorm(name='test_bn3d')(nn_3)
4648
cls.n3_b = n3_b
4749
cls.base_3d = Model(inputs=ni_3, outputs=n3_b, name='test_base_3d')
4850

51+
class bn_0d_model(Model):
52+
53+
def __init__(self):
54+
super(bn_0d_model, self).__init__()
55+
self.fc = Dense(32, in_channels=10)
56+
self.bn = BatchNorm(num_features=32, name='test_bn1d')
57+
58+
def forward(self, x):
59+
x = self.bn(self.fc(x))
60+
return x
61+
62+
dynamic_base = bn_0d_model()
63+
cls.n0_b = dynamic_base(cls.x0, is_train=True)
64+
65+
## 0D ========================================================================
66+
67+
nin_0 = Input(x_0_input_shape, name='test_in1')
68+
69+
n0 = Dense(32)(nin_0)
70+
n0 = BatchNorm1d(name='test_bn0d')(n0)
71+
72+
cls.n0 = n0
73+
74+
cls.static_0d = Model(inputs=nin_0, outputs=n0)
75+
76+
class bn_0d_model(Model):
77+
78+
def __init__(self):
79+
super(bn_0d_model, self).__init__(name='test_bn_0d_model')
80+
self.fc = Dense(32, in_channels=10)
81+
self.bn = BatchNorm1d(num_features=32, name='test_bn1d')
82+
83+
def forward(self, x):
84+
x = self.bn(self.fc(x))
85+
return x
86+
87+
cls.dynamic_0d = bn_0d_model()
88+
89+
print("Printing BatchNorm0d")
90+
print(cls.static_0d)
91+
print(cls.dynamic_0d)
92+
4993
## 1D ========================================================================
5094

5195
nin_1 = Input(x_1_input_shape, name='test_in1')
@@ -147,6 +191,14 @@ def test_BatchNorm(self):
147191
self.assertEqual(self.n3_b.shape[1:], (50, 50, 50, 32))
148192
out = self.base_3d(self.x3, is_train=True)
149193

194+
self.assertEqual(self.n0_b.shape[1:], (32))
195+
print("test_BatchNorm OK")
196+
197+
def test_BatchNorm0d(self):
198+
self.assertEqual(self.n0.shape[1:], (32))
199+
out = self.static_0d(self.x0, is_train=True)
200+
out = self.dynamic_0d(self.x0, is_train=True)
201+
150202
def test_BatchNorm1d(self):
151203
self.assertEqual(self.n1.shape[1:], (50, 32))
152204
out = self.static_1d(self.x1, is_train=True)
@@ -189,6 +241,26 @@ def test_exception(self):
189241
self.assertIsInstance(e, ValueError)
190242
print(e)
191243

244+
def test_input_shape(self):
245+
try:
246+
bn = BatchNorm1d(num_features=32)
247+
out = bn(self.x2)
248+
except Exception as e:
249+
self.assertIsInstance(e, ValueError)
250+
print(e)
251+
try:
252+
bn = BatchNorm2d(num_features=32)
253+
out = bn(self.x3)
254+
except Exception as e:
255+
self.assertIsInstance(e, ValueError)
256+
print(e)
257+
try:
258+
bn = BatchNorm3d(num_features=32)
259+
out = bn(self.x1)
260+
except Exception as e:
261+
self.assertIsInstance(e, ValueError)
262+
print(e)
263+
192264

193265
if __name__ == '__main__':
194266

0 commit comments

Comments
 (0)