
black linting
DomInvivo committed Dec 13, 2023
1 parent 015373f commit e0f841a
Showing 1 changed file with 55 additions and 38 deletions: tests/test_ensemble_layers.py
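For context: each test below builds an Ensemble* layer, copies its per-member weights into `num_ensemble` standalone layers, and asserts that the batched ensemble output matches the member-by-member outputs. A minimal sketch of that property in plain PyTorch (the stacked-weight shapes and names here are illustrative assumptions, not graphium's actual implementation):

    # Illustrative sketch only (plain PyTorch, not graphium's implementation):
    # an ensemble linear layer stacks one weight matrix per member and applies
    # all of them in a single batched matmul, which must match a loop over
    # equivalent standalone layers.
    import torch

    num_ensemble, batch_size, in_dim, out_dim = 3, 13, 11, 5

    # Hypothetical stacked parameters: one (in_dim, out_dim) matrix per member
    weights = torch.randn(num_ensemble, in_dim, out_dim)
    biases = torch.randn(num_ensemble, 1, out_dim)

    x = torch.randn(batch_size, in_dim)

    # One batched matmul broadcasts x over the ensemble dim: (num_ensemble, batch_size, out_dim)
    ensemble_out = x @ weights + biases

    # The same result, member by member
    for i in range(num_ensemble):
        member_out = x @ weights[i] + biases[i]
        torch.testing.assert_close(ensemble_out[i], member_out)

The single matmul over stacked weights is what makes the ensemble layer equivalent to, but cheaper than, looping over its members.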
@@ -8,14 +8,19 @@
 import unittest as ut
 
 from graphium.nn.base_layers import FCLayer, MLP
-from graphium.nn.ensemble_layers import EnsembleLinear, EnsembleFCLayer, EnsembleMLP, EnsembleMuReadoutGraphium
+from graphium.nn.ensemble_layers import (
+    EnsembleLinear,
+    EnsembleFCLayer,
+    EnsembleMLP,
+    EnsembleMuReadoutGraphium,
+)
 
 
 class test_Ensemble_Layers(ut.TestCase):
-
     # for drop_rate=0.5, test if the output shape is correct
-    def check_ensemble_linear(self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim:int):
-
+    def check_ensemble_linear(
+        self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim: int
+    ):
         msg = f"Testing EnsembleLinear with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}"
 
         # Create EnsembleLinear instance
@@ -37,13 +42,11 @@ def check_ensemble_linear(self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim:int):
 
         # Make sure that the outputs of the individual layers are the same as the ensemble output
         for i, linear_layer in enumerate(linear_layers):
-
             individual_output = linear_layer(input_tensor)
             individual_output = individual_output.detach().numpy()
             ensemble_output_i = ensemble_output[i].detach().numpy()
             np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg)
 
-
         # Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimension
         if more_batch_dim:
             out_shape = (more_batch_dim, num_ensemble, batch_size, out_dim)
@@ -58,7 +61,6 @@ def check_ensemble_linear(self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim:int):
 
         # Make sure that the outputs of the individual layers are the same as the ensemble output
         for i, linear_layer in enumerate(linear_layers):
-
             if more_batch_dim:
                 individual_output = linear_layer(input_tensor[:, i])
                 ensemble_output_i = ensemble_output[:, i]
@@ -69,8 +71,6 @@ def check_ensemble_linear(self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim:int):
             ensemble_output_i = ensemble_output_i.detach().numpy()
             np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg)
 
-
-
     def test_ensemble_linear(self):
         # more_batch_dim=0
         self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0)
@@ -87,10 +87,16 @@ def test_ensemble_linear(self):
         self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=7)
         self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7)
 
-
     # for drop_rate=0.5, test if the output shape is correct
-    def check_ensemble_fclayer(self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim:int, is_readout_layer=False):
-
+    def check_ensemble_fclayer(
+        self,
+        in_dim: int,
+        out_dim: int,
+        num_ensemble: int,
+        batch_size: int,
+        more_batch_dim: int,
+        is_readout_layer=False,
+    ):
         msg = f"Testing EnsembleFCLayer with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}"
 
         # Create EnsembleFCLayer instance
@@ -112,13 +118,11 @@ def check_ensemble_fclayer(self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim:int, is_readout_layer=False):
 
         # Make sure that the outputs of the individual layers are the same as the ensemble output
         for i, fc_layer in enumerate(fc_layers):
-
             individual_output = fc_layer(input_tensor)
             individual_output = individual_output.detach().numpy()
             ensemble_output_i = ensemble_output[i].detach().numpy()
             np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg)
 
-
         # Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimension
         if more_batch_dim:
             out_shape = (more_batch_dim, num_ensemble, batch_size, out_dim)
@@ -133,7 +137,6 @@ def check_ensemble_fclayer(self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim:int, is_readout_layer=False):
 
         # Make sure that the outputs of the individual layers are the same as the ensemble output
         for i, fc_layer in enumerate(fc_layers):
-
             if more_batch_dim:
                 individual_output = fc_layer(input_tensor[:, i])
                 ensemble_output_i = ensemble_output[:, i]
@@ -144,8 +147,6 @@ def check_ensemble_fclayer(self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim:int, is_readout_layer=False):
             ensemble_output_i = ensemble_output_i.detach().numpy()
             np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg)
 
-
-
     def test_ensemble_fclayer(self):
         # more_batch_dim=0
         self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0)
@@ -163,24 +164,39 @@ def test_ensemble_fclayer(self):
         self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7)
 
         # Test `is_readout_layer`
-        self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, is_readout_layer=True)
-        self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, is_readout_layer=True)
-        self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, is_readout_layer=True)
-
-
-
+        self.check_ensemble_fclayer(
+            in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, is_readout_layer=True
+        )
+        self.check_ensemble_fclayer(
+            in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, is_readout_layer=True
+        )
+        self.check_ensemble_fclayer(
+            in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, is_readout_layer=True
+        )
 
     # for drop_rate=0.5, test if the output shape is correct
-    def check_ensemble_mlp(self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim:int, last_layer_is_readout=False):
-
+    def check_ensemble_mlp(
+        self,
+        in_dim: int,
+        out_dim: int,
+        num_ensemble: int,
+        batch_size: int,
+        more_batch_dim: int,
+        last_layer_is_readout=False,
+    ):
         msg = f"Testing EnsembleMLP with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}"
 
         # Create EnsembleMLP instance
         hidden_dims = [17, 17, 17]
-        ensemble_mlp = EnsembleMLP(in_dim, hidden_dims, out_dim, num_ensemble, last_layer_is_readout=last_layer_is_readout)
+        ensemble_mlp = EnsembleMLP(
+            in_dim, hidden_dims, out_dim, num_ensemble, last_layer_is_readout=last_layer_is_readout
+        )
 
         # Create equivalent separate MLP layers with synchronized weights and biases
-        mlps = [MLP(in_dim, hidden_dims, out_dim, last_layer_is_readout=last_layer_is_readout) for _ in range(num_ensemble)]
+        mlps = [
+            MLP(in_dim, hidden_dims, out_dim, last_layer_is_readout=last_layer_is_readout)
+            for _ in range(num_ensemble)
+        ]
         for i, mlp in enumerate(mlps):
             for j, layer in enumerate(mlp.fully_connected):
                 layer.linear.weight.data = ensemble_mlp.fully_connected[j].linear.weight.data[i]
@@ -196,13 +212,11 @@ def check_ensemble_mlp(self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim:int, last_layer_is_readout=False):
 
         # Make sure that the outputs of the individual layers are the same as the ensemble output
         for i, mlp in enumerate(mlps):
-
             individual_output = mlp(input_tensor)
             individual_output = individual_output.detach().numpy()
             ensemble_output_i = ensemble_output[i].detach().numpy()
             np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg)
 
-
         # Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimension
         if more_batch_dim:
             out_shape = (more_batch_dim, num_ensemble, batch_size, out_dim)
@@ -217,7 +231,6 @@ def check_ensemble_mlp(self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim:int, last_layer_is_readout=False):
 
         # Make sure that the outputs of the individual layers are the same as the ensemble output
         for i, mlp in enumerate(mlps):
-
             if more_batch_dim:
                 individual_output = mlp(input_tensor[:, i])
                 ensemble_output_i = ensemble_output[:, i]
@@ -228,8 +241,6 @@ def check_ensemble_mlp(self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim:int, last_layer_is_readout=False):
             ensemble_output_i = ensemble_output_i.detach().numpy()
             np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg)
 
-
-
     def test_ensemble_mlp(self):
         # more_batch_dim=0
         self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0)
@@ -247,10 +258,16 @@ def test_ensemble_mlp(self):
         self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7)
 
         # Test `last_layer_is_readout`
-        self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, last_layer_is_readout=True)
-        self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, last_layer_is_readout=True)
-        self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, last_layer_is_readout=True)
-
-
-if __name__ == '__main__':
+        self.check_ensemble_mlp(
+            in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, last_layer_is_readout=True
+        )
+        self.check_ensemble_mlp(
+            in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, last_layer_is_readout=True
+        )
+        self.check_ensemble_mlp(
+            in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, last_layer_is_readout=True
+        )
+
+
+if __name__ == "__main__":
     ut.main()
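For reference, the `more_batch_dim` cases above exercise three input layouts, matching the output shapes the tests assert (out_dim last, then batch_size, then num_ensemble, then the optional extra batch dim). A rough shape walk-through with plain torch.matmul broadcasting (dimension names are illustrative assumptions, not graphium code):

    # Rough sketch of the three input layouts, using matmul broadcasting
    import torch

    more_batch_dim, num_ensemble, batch_size, in_dim, out_dim = 7, 3, 13, 11, 5
    W = torch.randn(num_ensemble, in_dim, out_dim)

    shared = torch.randn(batch_size, in_dim)                               # same input for every member
    per_member = torch.randn(num_ensemble, batch_size, in_dim)             # one batch per member
    extra = torch.randn(more_batch_dim, num_ensemble, batch_size, in_dim)  # extra leading batch dim

    print((shared @ W).shape)      # torch.Size([3, 13, 5])
    print((per_member @ W).shape)  # torch.Size([3, 13, 5])
    print((extra @ W).shape)       # torch.Size([7, 3, 13, 5])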
