Skip to content

Commit 92de448

Browse files
committed
update BN unittest
1 parent e41e44e commit 92de448

File tree

2 files changed

+73
-8
lines changed

2 files changed

+73
-8
lines changed

tensorlayer/layers/normalization.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -221,12 +221,6 @@ def __init__(
221221
self.axes = None
222222

223223
if num_features is not None:
224-
if not isinstance(self, BatchNorm1d) and not isinstance(self, BatchNorm2d) and not isinstance(self,
225-
BatchNorm3d):
226-
raise ValueError(
227-
"Please use BatchNorm1d or BatchNorm2d or BatchNorm3d instead of BatchNorm "
228-
"if you want to specify 'num_features'."
229-
)
230224
self.build(None)
231225
self._built = True
232226

tests/layers/test_layers_normalization.py

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@ class Laye_BatchNorm_Test(CustomTestCase):
1818
@classmethod
1919
def setUpClass(cls):
2020

21+
x_0_input_shape = [None, 10]
2122
x_1_input_shape = [None, 100, 1]
2223
x_2_input_shape = [None, 100, 100, 3]
2324
x_3_input_shape = [None, 100, 100, 100, 3]
2425
batchsize = 2
2526

27+
cls.x0 = tf.random.normal([batchsize] + x_0_input_shape[1:])
2628
cls.x1 = tf.random.normal([batchsize] + x_1_input_shape[1:])
2729
cls.x2 = tf.random.normal([batchsize] + x_2_input_shape[1:])
2830
cls.x3 = tf.random.normal([batchsize] + x_3_input_shape[1:])
@@ -36,16 +38,58 @@ def setUpClass(cls):
3638

3739
ni_2 = Input(x_2_input_shape, name='test_ni2')
3840
nn_2 = Conv2d(n_filter=32, filter_size=(3, 3), strides=(2, 2), name='test_conv2d')(ni_2)
39-
n2_b = BatchNorm2d(name='test_bn2d')(nn_2)
41+
n2_b = BatchNorm(name='test_bn2d')(nn_2)
4042
cls.n2_b = n2_b
4143
cls.base_2d = Model(inputs=ni_2, outputs=n2_b, name='test_base_2d')
4244

4345
ni_3 = Input(x_3_input_shape, name='test_ni2')
4446
nn_3 = Conv3d(n_filter=32, filter_size=(3, 3, 3), strides=(2, 2, 2), name='test_conv3d')(ni_3)
45-
n3_b = BatchNorm3d(name='test_bn3d')(nn_3)
47+
n3_b = BatchNorm(name='test_bn3d')(nn_3)
4648
cls.n3_b = n3_b
4749
cls.base_3d = Model(inputs=ni_3, outputs=n3_b, name='test_base_3d')
4850

51+
class bn_0d_model(Model):
52+
53+
def __init__(self):
54+
super(bn_0d_model, self).__init__()
55+
self.fc = Dense(32, in_channels=10)
56+
self.bn = BatchNorm(num_features=32, name='test_bn1d')
57+
58+
def forward(self, x):
59+
x = self.bn(self.fc(x))
60+
return x
61+
62+
dynamic_base = bn_0d_model()
63+
cls.n0_b = dynamic_base(cls.x0, is_train=True)
64+
65+
## 0D ========================================================================
66+
67+
nin_0 = Input(x_0_input_shape, name='test_in1')
68+
69+
n0 = Dense(32)(nin_0)
70+
n0 = BatchNorm1d(name='test_bn0d')(n0)
71+
72+
cls.n0 = n0
73+
74+
cls.static_0d = Model(inputs=nin_0, outputs=n0)
75+
76+
class bn_0d_model(Model):
77+
78+
def __init__(self):
79+
super(bn_0d_model, self).__init__(name='test_bn_0d_model')
80+
self.fc = Dense(32, in_channels=10)
81+
self.bn = BatchNorm1d(num_features=32, name='test_bn1d')
82+
83+
def forward(self, x):
84+
x = self.bn(self.fc(x))
85+
return x
86+
87+
cls.dynamic_0d = bn_0d_model()
88+
89+
print("Printing BatchNorm0d")
90+
print(cls.static_0d)
91+
print(cls.dynamic_0d)
92+
4993
## 1D ========================================================================
5094

5195
nin_1 = Input(x_1_input_shape, name='test_in1')
@@ -147,6 +191,14 @@ def test_BatchNorm(self):
147191
self.assertEqual(self.n3_b.shape[1:], (50, 50, 50, 32))
148192
out = self.base_3d(self.x3, is_train=True)
149193

194+
self.assertEqual(self.n0_b.shape[1:], (32))
195+
print("test_BatchNorm OK")
196+
197+
def test_BatchNorm0d(self):
198+
self.assertEqual(self.n0.shape[1:], (32))
199+
out = self.static_0d(self.x0, is_train=True)
200+
out = self.dynamic_0d(self.x0, is_train=True)
201+
150202
def test_BatchNorm1d(self):
151203
self.assertEqual(self.n1.shape[1:], (50, 32))
152204
out = self.static_1d(self.x1, is_train=True)
@@ -189,6 +241,25 @@ def test_exception(self):
189241
self.assertIsInstance(e, ValueError)
190242
print(e)
191243

244+
def test_input_shape(self):
245+
try:
246+
bn = BatchNorm1d(num_features=32)
247+
out = bn(self.x2)
248+
except Exception as e:
249+
self.assertIsInstance(e, ValueError)
250+
print(e)
251+
try:
252+
bn = BatchNorm2d(num_features=32)
253+
out = bn(self.x3)
254+
except Exception as e:
255+
self.assertIsInstance(e, ValueError)
256+
print(e)
257+
try:
258+
bn = BatchNorm3d(num_features=32)
259+
out = bn(self.x1)
260+
except Exception as e:
261+
self.assertIsInstance(e, ValueError)
262+
print(e)
192263

193264
if __name__ == '__main__':
194265

0 commit comments

Comments
 (0)