8
8
import unittest as ut
9
9
10
10
from graphium .nn .base_layers import FCLayer , MLP
11
- from graphium .nn .ensemble_layers import EnsembleLinear , EnsembleFCLayer , EnsembleMLP , EnsembleMuReadoutGraphium
11
+ from graphium .nn .ensemble_layers import (
12
+ EnsembleLinear ,
13
+ EnsembleFCLayer ,
14
+ EnsembleMLP ,
15
+ EnsembleMuReadoutGraphium ,
16
+ )
12
17
13
18
14
19
class test_Ensemble_Layers (ut .TestCase ):
15
-
16
20
# for drop_rate=0.5, test if the output shape is correct
17
- def check_ensemble_linear (self , in_dim : int , out_dim : int , num_ensemble : int , batch_size : int , more_batch_dim :int ):
18
-
21
+ def check_ensemble_linear (
22
+ self , in_dim : int , out_dim : int , num_ensemble : int , batch_size : int , more_batch_dim : int
23
+ ):
19
24
msg = f"Testing EnsembleLinear with in_dim={ in_dim } , out_dim={ out_dim } , num_ensemble={ num_ensemble } , batch_size={ batch_size } , more_batch_dim={ more_batch_dim } "
20
25
21
26
# Create EnsembleLinear instance
@@ -37,13 +42,11 @@ def check_ensemble_linear(self, in_dim: int, out_dim: int, num_ensemble: int, ba
37
42
38
43
# Make sure that the outputs of the individual layers are the same as the ensemble output
39
44
for i , linear_layer in enumerate (linear_layers ):
40
-
41
45
individual_output = linear_layer (input_tensor )
42
46
individual_output = individual_output .detach ().numpy ()
43
47
ensemble_output_i = ensemble_output [i ].detach ().numpy ()
44
48
np .testing .assert_allclose (ensemble_output_i , individual_output , atol = 1e-5 , err_msg = msg )
45
49
46
-
47
50
# Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimension
48
51
if more_batch_dim :
49
52
out_shape = (more_batch_dim , num_ensemble , batch_size , out_dim )
@@ -58,7 +61,6 @@ def check_ensemble_linear(self, in_dim: int, out_dim: int, num_ensemble: int, ba
58
61
59
62
# Make sure that the outputs of the individual layers are the same as the ensemble output
60
63
for i , linear_layer in enumerate (linear_layers ):
61
-
62
64
if more_batch_dim :
63
65
individual_output = linear_layer (input_tensor [:, i ])
64
66
ensemble_output_i = ensemble_output [:, i ]
@@ -69,8 +71,6 @@ def check_ensemble_linear(self, in_dim: int, out_dim: int, num_ensemble: int, ba
69
71
ensemble_output_i = ensemble_output_i .detach ().numpy ()
70
72
np .testing .assert_allclose (ensemble_output_i , individual_output , atol = 1e-5 , err_msg = msg )
71
73
72
-
73
-
74
74
def test_ensemble_linear (self ):
75
75
# more_batch_dim=0
76
76
self .check_ensemble_linear (in_dim = 11 , out_dim = 5 , num_ensemble = 3 , batch_size = 13 , more_batch_dim = 0 )
@@ -87,10 +87,16 @@ def test_ensemble_linear(self):
87
87
self .check_ensemble_linear (in_dim = 11 , out_dim = 5 , num_ensemble = 3 , batch_size = 1 , more_batch_dim = 7 )
88
88
self .check_ensemble_linear (in_dim = 11 , out_dim = 5 , num_ensemble = 1 , batch_size = 13 , more_batch_dim = 7 )
89
89
90
-
91
90
# for drop_rate=0.5, test if the output shape is correct
92
- def check_ensemble_fclayer (self , in_dim : int , out_dim : int , num_ensemble : int , batch_size : int , more_batch_dim :int , is_readout_layer = False ):
93
-
91
+ def check_ensemble_fclayer (
92
+ self ,
93
+ in_dim : int ,
94
+ out_dim : int ,
95
+ num_ensemble : int ,
96
+ batch_size : int ,
97
+ more_batch_dim : int ,
98
+ is_readout_layer = False ,
99
+ ):
94
100
msg = f"Testing EnsembleFCLayer with in_dim={ in_dim } , out_dim={ out_dim } , num_ensemble={ num_ensemble } , batch_size={ batch_size } , more_batch_dim={ more_batch_dim } "
95
101
96
102
# Create EnsembleFCLayer instance
@@ -112,13 +118,11 @@ def check_ensemble_fclayer(self, in_dim: int, out_dim: int, num_ensemble: int, b
112
118
113
119
# Make sure that the outputs of the individual layers are the same as the ensemble output
114
120
for i , fc_layer in enumerate (fc_layers ):
115
-
116
121
individual_output = fc_layer (input_tensor )
117
122
individual_output = individual_output .detach ().numpy ()
118
123
ensemble_output_i = ensemble_output [i ].detach ().numpy ()
119
124
np .testing .assert_allclose (ensemble_output_i , individual_output , atol = 1e-5 , err_msg = msg )
120
125
121
-
122
126
# Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimension
123
127
if more_batch_dim :
124
128
out_shape = (more_batch_dim , num_ensemble , batch_size , out_dim )
@@ -133,7 +137,6 @@ def check_ensemble_fclayer(self, in_dim: int, out_dim: int, num_ensemble: int, b
133
137
134
138
# Make sure that the outputs of the individual layers are the same as the ensemble output
135
139
for i , fc_layer in enumerate (fc_layers ):
136
-
137
140
if more_batch_dim :
138
141
individual_output = fc_layer (input_tensor [:, i ])
139
142
ensemble_output_i = ensemble_output [:, i ]
@@ -144,8 +147,6 @@ def check_ensemble_fclayer(self, in_dim: int, out_dim: int, num_ensemble: int, b
144
147
ensemble_output_i = ensemble_output_i .detach ().numpy ()
145
148
np .testing .assert_allclose (ensemble_output_i , individual_output , atol = 1e-5 , err_msg = msg )
146
149
147
-
148
-
149
150
def test_ensemble_fclayer (self ):
150
151
# more_batch_dim=0
151
152
self .check_ensemble_fclayer (in_dim = 11 , out_dim = 5 , num_ensemble = 3 , batch_size = 13 , more_batch_dim = 0 )
@@ -163,24 +164,39 @@ def test_ensemble_fclayer(self):
163
164
self .check_ensemble_fclayer (in_dim = 11 , out_dim = 5 , num_ensemble = 1 , batch_size = 13 , more_batch_dim = 7 )
164
165
165
166
# Test `is_readout_layer`
166
- self .check_ensemble_fclayer (in_dim = 11 , out_dim = 5 , num_ensemble = 3 , batch_size = 13 , more_batch_dim = 0 , is_readout_layer = True )
167
- self .check_ensemble_fclayer (in_dim = 11 , out_dim = 5 , num_ensemble = 3 , batch_size = 13 , more_batch_dim = 1 , is_readout_layer = True )
168
- self .check_ensemble_fclayer (in_dim = 11 , out_dim = 5 , num_ensemble = 3 , batch_size = 13 , more_batch_dim = 7 , is_readout_layer = True )
169
-
170
-
171
-
167
+ self .check_ensemble_fclayer (
168
+ in_dim = 11 , out_dim = 5 , num_ensemble = 3 , batch_size = 13 , more_batch_dim = 0 , is_readout_layer = True
169
+ )
170
+ self .check_ensemble_fclayer (
171
+ in_dim = 11 , out_dim = 5 , num_ensemble = 3 , batch_size = 13 , more_batch_dim = 1 , is_readout_layer = True
172
+ )
173
+ self .check_ensemble_fclayer (
174
+ in_dim = 11 , out_dim = 5 , num_ensemble = 3 , batch_size = 13 , more_batch_dim = 7 , is_readout_layer = True
175
+ )
172
176
173
177
# for drop_rate=0.5, test if the output shape is correct
174
- def check_ensemble_mlp (self , in_dim : int , out_dim : int , num_ensemble : int , batch_size : int , more_batch_dim :int , last_layer_is_readout = False ):
175
-
178
+ def check_ensemble_mlp (
179
+ self ,
180
+ in_dim : int ,
181
+ out_dim : int ,
182
+ num_ensemble : int ,
183
+ batch_size : int ,
184
+ more_batch_dim : int ,
185
+ last_layer_is_readout = False ,
186
+ ):
176
187
msg = f"Testing EnsembleMLP with in_dim={ in_dim } , out_dim={ out_dim } , num_ensemble={ num_ensemble } , batch_size={ batch_size } , more_batch_dim={ more_batch_dim } "
177
188
178
189
# Create EnsembleMLP instance
179
190
hidden_dims = [17 , 17 , 17 ]
180
- ensemble_mlp = EnsembleMLP (in_dim , hidden_dims , out_dim , num_ensemble , last_layer_is_readout = last_layer_is_readout )
191
+ ensemble_mlp = EnsembleMLP (
192
+ in_dim , hidden_dims , out_dim , num_ensemble , last_layer_is_readout = last_layer_is_readout
193
+ )
181
194
182
195
# Create equivalent separate MLP layers with synchronized weights and biases
183
- mlps = [MLP (in_dim , hidden_dims , out_dim , last_layer_is_readout = last_layer_is_readout ) for _ in range (num_ensemble )]
196
+ mlps = [
197
+ MLP (in_dim , hidden_dims , out_dim , last_layer_is_readout = last_layer_is_readout )
198
+ for _ in range (num_ensemble )
199
+ ]
184
200
for i , mlp in enumerate (mlps ):
185
201
for j , layer in enumerate (mlp .fully_connected ):
186
202
layer .linear .weight .data = ensemble_mlp .fully_connected [j ].linear .weight .data [i ]
@@ -196,13 +212,11 @@ def check_ensemble_mlp(self, in_dim: int, out_dim: int, num_ensemble: int, batch
196
212
197
213
# Make sure that the outputs of the individual layers are the same as the ensemble output
198
214
for i , mlp in enumerate (mlps ):
199
-
200
215
individual_output = mlp (input_tensor )
201
216
individual_output = individual_output .detach ().numpy ()
202
217
ensemble_output_i = ensemble_output [i ].detach ().numpy ()
203
218
np .testing .assert_allclose (ensemble_output_i , individual_output , atol = 1e-5 , err_msg = msg )
204
219
205
-
206
220
# Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimension
207
221
if more_batch_dim :
208
222
out_shape = (more_batch_dim , num_ensemble , batch_size , out_dim )
@@ -217,7 +231,6 @@ def check_ensemble_mlp(self, in_dim: int, out_dim: int, num_ensemble: int, batch
217
231
218
232
# Make sure that the outputs of the individual layers are the same as the ensemble output
219
233
for i , mlp in enumerate (mlps ):
220
-
221
234
if more_batch_dim :
222
235
individual_output = mlp (input_tensor [:, i ])
223
236
ensemble_output_i = ensemble_output [:, i ]
@@ -228,8 +241,6 @@ def check_ensemble_mlp(self, in_dim: int, out_dim: int, num_ensemble: int, batch
228
241
ensemble_output_i = ensemble_output_i .detach ().numpy ()
229
242
np .testing .assert_allclose (ensemble_output_i , individual_output , atol = 1e-5 , err_msg = msg )
230
243
231
-
232
-
233
244
def test_ensemble_mlp (self ):
234
245
# more_batch_dim=0
235
246
self .check_ensemble_mlp (in_dim = 11 , out_dim = 5 , num_ensemble = 3 , batch_size = 13 , more_batch_dim = 0 )
@@ -247,10 +258,16 @@ def test_ensemble_mlp(self):
247
258
self .check_ensemble_mlp (in_dim = 11 , out_dim = 5 , num_ensemble = 1 , batch_size = 13 , more_batch_dim = 7 )
248
259
249
260
# Test `last_layer_is_readout`
250
- self .check_ensemble_mlp (in_dim = 11 , out_dim = 5 , num_ensemble = 3 , batch_size = 13 , more_batch_dim = 0 , last_layer_is_readout = True )
251
- self .check_ensemble_mlp (in_dim = 11 , out_dim = 5 , num_ensemble = 3 , batch_size = 13 , more_batch_dim = 1 , last_layer_is_readout = True )
252
- self .check_ensemble_mlp (in_dim = 11 , out_dim = 5 , num_ensemble = 3 , batch_size = 13 , more_batch_dim = 7 , last_layer_is_readout = True )
253
-
254
-
255
- if __name__ == '__main__' :
261
+ self .check_ensemble_mlp (
262
+ in_dim = 11 , out_dim = 5 , num_ensemble = 3 , batch_size = 13 , more_batch_dim = 0 , last_layer_is_readout = True
263
+ )
264
+ self .check_ensemble_mlp (
265
+ in_dim = 11 , out_dim = 5 , num_ensemble = 3 , batch_size = 13 , more_batch_dim = 1 , last_layer_is_readout = True
266
+ )
267
+ self .check_ensemble_mlp (
268
+ in_dim = 11 , out_dim = 5 , num_ensemble = 3 , batch_size = 13 , more_batch_dim = 7 , last_layer_is_readout = True
269
+ )
270
+
271
+
272
+ if __name__ == "__main__" :
256
273
ut .main ()
0 commit comments