Commit e0f841a

black linting
1 parent 015373f commit e0f841a
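
For context, the rewrites below are exactly what black produces mechanically; nothing is hand-edited. A minimal sketch of the same transformation through black's Python API, assuming the repository configures a line length of 110 (the wrap points in the diff are consistent with that, but the actual pyproject.toml setting is an assumption):

    import black

    # The over-long import touched by this commit, reconstructed from the diff.
    src = (
        "from graphium.nn.ensemble_layers import EnsembleLinear, "
        "EnsembleFCLayer, EnsembleMLP, EnsembleMuReadoutGraphium\n"
    )

    # black rewraps any logical line that exceeds the configured limit; here it
    # emits the parenthesized, trailing-comma import seen in the first hunk below.
    print(black.format_str(src, mode=black.Mode(line_length=110)))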

File tree

1 file changed · +55 −38 lines changed

tests/test_ensemble_layers.py

Lines changed: 55 additions & 38 deletions
@@ -8,14 +8,19 @@
 import unittest as ut
 
 from graphium.nn.base_layers import FCLayer, MLP
-from graphium.nn.ensemble_layers import EnsembleLinear, EnsembleFCLayer, EnsembleMLP, EnsembleMuReadoutGraphium
+from graphium.nn.ensemble_layers import (
+    EnsembleLinear,
+    EnsembleFCLayer,
+    EnsembleMLP,
+    EnsembleMuReadoutGraphium,
+)
 
 
 class test_Ensemble_Layers(ut.TestCase):
-
     # for drop_rate=0.5, test if the output shape is correct
-    def check_ensemble_linear(self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim:int):
-
+    def check_ensemble_linear(
+        self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim: int
+    ):
         msg = f"Testing EnsembleLinear with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}"
 
         # Create EnsembleLinear instance
@@ -37,13 +42,11 @@ def check_ensemble_linear(self, in_dim: int, out_dim: int, num_ensemble: int, ba
 
         # Make sure that the outputs of the individual layers are the same as the ensemble output
         for i, linear_layer in enumerate(linear_layers):
-
             individual_output = linear_layer(input_tensor)
             individual_output = individual_output.detach().numpy()
             ensemble_output_i = ensemble_output[i].detach().numpy()
             np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg)
 
-
         # Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimension
         if more_batch_dim:
             out_shape = (more_batch_dim, num_ensemble, batch_size, out_dim)
@@ -58,7 +61,6 @@ def check_ensemble_linear(self, in_dim: int, out_dim: int, num_ensemble: int, ba
 
         # Make sure that the outputs of the individual layers are the same as the ensemble output
         for i, linear_layer in enumerate(linear_layers):
-
             if more_batch_dim:
                 individual_output = linear_layer(input_tensor[:, i])
                 ensemble_output_i = ensemble_output[:, i]
@@ -69,8 +71,6 @@ def check_ensemble_linear(self, in_dim: int, out_dim: int, num_ensemble: int, ba
             ensemble_output_i = ensemble_output_i.detach().numpy()
             np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg)
 
-
-
     def test_ensemble_linear(self):
         # more_batch_dim=0
         self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0)
@@ -87,10 +87,16 @@ def test_ensemble_linear(self):
         self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=7)
         self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7)
 
-
     # for drop_rate=0.5, test if the output shape is correct
-    def check_ensemble_fclayer(self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim:int, is_readout_layer=False):
-
+    def check_ensemble_fclayer(
+        self,
+        in_dim: int,
+        out_dim: int,
+        num_ensemble: int,
+        batch_size: int,
+        more_batch_dim: int,
+        is_readout_layer=False,
+    ):
         msg = f"Testing EnsembleFCLayer with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}"
 
         # Create EnsembleFCLayer instance
@@ -112,13 +118,11 @@ def check_ensemble_fclayer(self, in_dim: int, out_dim: int, num_ensemble: int, b
 
         # Make sure that the outputs of the individual layers are the same as the ensemble output
         for i, fc_layer in enumerate(fc_layers):
-
             individual_output = fc_layer(input_tensor)
             individual_output = individual_output.detach().numpy()
             ensemble_output_i = ensemble_output[i].detach().numpy()
             np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg)
 
-
         # Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimension
         if more_batch_dim:
             out_shape = (more_batch_dim, num_ensemble, batch_size, out_dim)
@@ -133,7 +137,6 @@ def check_ensemble_fclayer(self, in_dim: int, out_dim: int, num_ensemble: int, b
 
         # Make sure that the outputs of the individual layers are the same as the ensemble output
         for i, fc_layer in enumerate(fc_layers):
-
             if more_batch_dim:
                 individual_output = fc_layer(input_tensor[:, i])
                 ensemble_output_i = ensemble_output[:, i]
@@ -144,8 +147,6 @@ def check_ensemble_fclayer(self, in_dim: int, out_dim: int, num_ensemble: int, b
             ensemble_output_i = ensemble_output_i.detach().numpy()
             np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg)
 
-
-
     def test_ensemble_fclayer(self):
         # more_batch_dim=0
         self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0)
@@ -163,24 +164,39 @@ def test_ensemble_fclayer(self):
         self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7)
 
         # Test `is_readout_layer`
-        self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, is_readout_layer=True)
-        self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, is_readout_layer=True)
-        self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, is_readout_layer=True)
-
-
-
+        self.check_ensemble_fclayer(
+            in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, is_readout_layer=True
+        )
+        self.check_ensemble_fclayer(
+            in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, is_readout_layer=True
+        )
+        self.check_ensemble_fclayer(
+            in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, is_readout_layer=True
+        )
 
     # for drop_rate=0.5, test if the output shape is correct
-    def check_ensemble_mlp(self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim:int, last_layer_is_readout=False):
-
+    def check_ensemble_mlp(
+        self,
+        in_dim: int,
+        out_dim: int,
+        num_ensemble: int,
+        batch_size: int,
+        more_batch_dim: int,
+        last_layer_is_readout=False,
+    ):
         msg = f"Testing EnsembleMLP with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}"
 
         # Create EnsembleMLP instance
         hidden_dims = [17, 17, 17]
-        ensemble_mlp = EnsembleMLP(in_dim, hidden_dims, out_dim, num_ensemble, last_layer_is_readout=last_layer_is_readout)
+        ensemble_mlp = EnsembleMLP(
+            in_dim, hidden_dims, out_dim, num_ensemble, last_layer_is_readout=last_layer_is_readout
+        )
 
         # Create equivalent separate MLP layers with synchronized weights and biases
-        mlps = [MLP(in_dim, hidden_dims, out_dim, last_layer_is_readout=last_layer_is_readout) for _ in range(num_ensemble)]
+        mlps = [
+            MLP(in_dim, hidden_dims, out_dim, last_layer_is_readout=last_layer_is_readout)
+            for _ in range(num_ensemble)
+        ]
         for i, mlp in enumerate(mlps):
             for j, layer in enumerate(mlp.fully_connected):
                 layer.linear.weight.data = ensemble_mlp.fully_connected[j].linear.weight.data[i]
@@ -196,13 +212,11 @@ def check_ensemble_mlp(self, in_dim: int, out_dim: int, num_ensemble: int, batch
 
         # Make sure that the outputs of the individual layers are the same as the ensemble output
         for i, mlp in enumerate(mlps):
-
             individual_output = mlp(input_tensor)
             individual_output = individual_output.detach().numpy()
             ensemble_output_i = ensemble_output[i].detach().numpy()
             np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg)
 
-
         # Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimension
         if more_batch_dim:
             out_shape = (more_batch_dim, num_ensemble, batch_size, out_dim)
@@ -217,7 +231,6 @@ def check_ensemble_mlp(self, in_dim: int, out_dim: int, num_ensemble: int, batch
 
         # Make sure that the outputs of the individual layers are the same as the ensemble output
         for i, mlp in enumerate(mlps):
-
             if more_batch_dim:
                 individual_output = mlp(input_tensor[:, i])
                 ensemble_output_i = ensemble_output[:, i]
@@ -228,8 +241,6 @@ def check_ensemble_mlp(self, in_dim: int, out_dim: int, num_ensemble: int, batch
             ensemble_output_i = ensemble_output_i.detach().numpy()
             np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg)
 
-
-
     def test_ensemble_mlp(self):
         # more_batch_dim=0
         self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0)
@@ -247,10 +258,16 @@ def test_ensemble_mlp(self):
         self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7)
 
         # Test `last_layer_is_readout`
-        self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, last_layer_is_readout=True)
-        self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, last_layer_is_readout=True)
-        self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, last_layer_is_readout=True)
-
-
-if __name__ == '__main__':
+        self.check_ensemble_mlp(
+            in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, last_layer_is_readout=True
+        )
+        self.check_ensemble_mlp(
+            in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, last_layer_is_readout=True
+        )
+        self.check_ensemble_mlp(
+            in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, last_layer_is_readout=True
+        )
+
+
+if __name__ == "__main__":
     ut.main()
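
Since the commit is formatting-only, the tests should pass unchanged. A minimal sketch for running just this module with the standard unittest runner (equivalent to `python -m unittest tests.test_ensemble_layers -v` from the repository root; the pytest runner would work as well):

    import unittest

    # Load and run only the ensemble-layer tests.
    suite = unittest.defaultTestLoader.loadTestsFromName("tests.test_ensemble_layers")
    unittest.TextTestRunner(verbosity=2).run(suite)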
