@@ -18,11 +18,13 @@ class Laye_BatchNorm_Test(CustomTestCase):
18
18
@classmethod
19
19
def setUpClass (cls ):
20
20
21
+ x_0_input_shape = [None , 10 ]
21
22
x_1_input_shape = [None , 100 , 1 ]
22
23
x_2_input_shape = [None , 100 , 100 , 3 ]
23
24
x_3_input_shape = [None , 100 , 100 , 100 , 3 ]
24
25
batchsize = 2
25
26
27
+ cls .x0 = tf .random .normal ([batchsize ] + x_0_input_shape [1 :])
26
28
cls .x1 = tf .random .normal ([batchsize ] + x_1_input_shape [1 :])
27
29
cls .x2 = tf .random .normal ([batchsize ] + x_2_input_shape [1 :])
28
30
cls .x3 = tf .random .normal ([batchsize ] + x_3_input_shape [1 :])
@@ -36,16 +38,58 @@ def setUpClass(cls):
36
38
37
39
ni_2 = Input (x_2_input_shape , name = 'test_ni2' )
38
40
nn_2 = Conv2d (n_filter = 32 , filter_size = (3 , 3 ), strides = (2 , 2 ), name = 'test_conv2d' )(ni_2 )
39
- n2_b = BatchNorm2d (name = 'test_bn2d' )(nn_2 )
41
+ n2_b = BatchNorm (name = 'test_bn2d' )(nn_2 )
40
42
cls .n2_b = n2_b
41
43
cls .base_2d = Model (inputs = ni_2 , outputs = n2_b , name = 'test_base_2d' )
42
44
43
45
ni_3 = Input (x_3_input_shape , name = 'test_ni2' )
44
46
nn_3 = Conv3d (n_filter = 32 , filter_size = (3 , 3 , 3 ), strides = (2 , 2 , 2 ), name = 'test_conv3d' )(ni_3 )
45
- n3_b = BatchNorm3d (name = 'test_bn3d' )(nn_3 )
47
+ n3_b = BatchNorm (name = 'test_bn3d' )(nn_3 )
46
48
cls .n3_b = n3_b
47
49
cls .base_3d = Model (inputs = ni_3 , outputs = n3_b , name = 'test_base_3d' )
48
50
51
+ class bn_0d_model (Model ):
52
+
53
+ def __init__ (self ):
54
+ super (bn_0d_model , self ).__init__ ()
55
+ self .fc = Dense (32 , in_channels = 10 )
56
+ self .bn = BatchNorm (num_features = 32 , name = 'test_bn1d' )
57
+
58
+ def forward (self , x ):
59
+ x = self .bn (self .fc (x ))
60
+ return x
61
+
62
+ dynamic_base = bn_0d_model ()
63
+ cls .n0_b = dynamic_base (cls .x0 , is_train = True )
64
+
65
+ ## 0D ========================================================================
66
+
67
+ nin_0 = Input (x_0_input_shape , name = 'test_in1' )
68
+
69
+ n0 = Dense (32 )(nin_0 )
70
+ n0 = BatchNorm1d (name = 'test_bn0d' )(n0 )
71
+
72
+ cls .n0 = n0
73
+
74
+ cls .static_0d = Model (inputs = nin_0 , outputs = n0 )
75
+
76
+ class bn_0d_model (Model ):
77
+
78
+ def __init__ (self ):
79
+ super (bn_0d_model , self ).__init__ (name = 'test_bn_0d_model' )
80
+ self .fc = Dense (32 , in_channels = 10 )
81
+ self .bn = BatchNorm1d (num_features = 32 , name = 'test_bn1d' )
82
+
83
+ def forward (self , x ):
84
+ x = self .bn (self .fc (x ))
85
+ return x
86
+
87
+ cls .dynamic_0d = bn_0d_model ()
88
+
89
+ print ("Printing BatchNorm0d" )
90
+ print (cls .static_0d )
91
+ print (cls .dynamic_0d )
92
+
49
93
## 1D ========================================================================
50
94
51
95
nin_1 = Input (x_1_input_shape , name = 'test_in1' )
@@ -147,6 +191,14 @@ def test_BatchNorm(self):
147
191
self .assertEqual (self .n3_b .shape [1 :], (50 , 50 , 50 , 32 ))
148
192
out = self .base_3d (self .x3 , is_train = True )
149
193
194
+ self .assertEqual (self .n0_b .shape [1 :], (32 ))
195
+ print ("test_BatchNorm OK" )
196
+
197
+ def test_BatchNorm0d (self ):
198
+ self .assertEqual (self .n0 .shape [1 :], (32 ))
199
+ out = self .static_0d (self .x0 , is_train = True )
200
+ out = self .dynamic_0d (self .x0 , is_train = True )
201
+
150
202
def test_BatchNorm1d (self ):
151
203
self .assertEqual (self .n1 .shape [1 :], (50 , 32 ))
152
204
out = self .static_1d (self .x1 , is_train = True )
@@ -189,6 +241,25 @@ def test_exception(self):
189
241
self .assertIsInstance (e , ValueError )
190
242
print (e )
191
243
244
+ def test_input_shape (self ):
245
+ try :
246
+ bn = BatchNorm1d (num_features = 32 )
247
+ out = bn (self .x2 )
248
+ except Exception as e :
249
+ self .assertIsInstance (e , ValueError )
250
+ print (e )
251
+ try :
252
+ bn = BatchNorm2d (num_features = 32 )
253
+ out = bn (self .x3 )
254
+ except Exception as e :
255
+ self .assertIsInstance (e , ValueError )
256
+ print (e )
257
+ try :
258
+ bn = BatchNorm3d (num_features = 32 )
259
+ out = bn (self .x1 )
260
+ except Exception as e :
261
+ self .assertIsInstance (e , ValueError )
262
+ print (e )
192
263
193
264
if __name__ == '__main__' :
194
265
0 commit comments