From dd52ca5122746eccfb2ab2188f3c77e11d75642a Mon Sep 17 00:00:00 2001 From: DomInvivo Date: Wed, 13 Dec 2023 17:42:48 -0500 Subject: [PATCH] Added tests for mu_readout --- tests/test_ensemble_layers.py | 59 +++++++++++++++++++++++++++++++---- 1 file changed, 53 insertions(+), 6 deletions(-) diff --git a/tests/test_ensemble_layers.py b/tests/test_ensemble_layers.py index ff96e0fad..66de95e48 100644 --- a/tests/test_ensemble_layers.py +++ b/tests/test_ensemble_layers.py @@ -7,7 +7,7 @@ from torch.nn import Linear import unittest as ut -from graphium.nn.base_layers import FCLayer, MLP +from graphium.nn.base_layers import FCLayer, MLP, MuReadoutGraphium from graphium.nn.ensemble_layers import ( EnsembleLinear, EnsembleFCLayer, @@ -19,15 +19,27 @@ class test_Ensemble_Layers(ut.TestCase): # for drop_rate=0.5, test if the output shape is correct def check_ensemble_linear( - self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim: int + self, + in_dim: int, + out_dim: int, + num_ensemble: int, + batch_size: int, + more_batch_dim: int, + use_mureadout=False, ): msg = f"Testing EnsembleLinear with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}" - # Create EnsembleLinear instance - ensemble_linear = EnsembleLinear(in_dim, out_dim, num_ensemble) + if use_mureadout: + # Create EnsembleMuReadoutGraphium instance + ensemble_linear = EnsembleMuReadoutGraphium(in_dim, out_dim, num_ensemble) + # Create equivalent separate Linear layers with synchronized weights and biases + linear_layers = [MuReadoutGraphium(in_dim, out_dim) for _ in range(num_ensemble)] + else: + # Create EnsembleLinear instance + ensemble_linear = EnsembleLinear(in_dim, out_dim, num_ensemble) + # Create equivalent separate Linear layers with synchronized weights and biases + linear_layers = [Linear(in_dim, out_dim) for _ in range(num_ensemble)] - # Create equivalent separate Linear layers with synchronized weights and biases - linear_layers = [Linear(in_dim, out_dim) for _ in range(num_ensemble)] for i, linear_layer in enumerate(linear_layers): linear_layer.weight.data = ensemble_linear.weight.data[i] if ensemble_linear.bias is not None: @@ -87,6 +99,41 @@ def test_ensemble_linear(self): self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=7) self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7) + def test_ensemble_mureadout_graphium(self): + # Test `use_mureadout` + # more_batch_dim=0 + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, use_mureadout=True + ) + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=0, use_mureadout=True + ) + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=0, use_mureadout=True + ) + + # more_batch_dim=1 + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, use_mureadout=True + ) + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=1, use_mureadout=True + ) + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=1, use_mureadout=True + ) + + # more_batch_dim=7 + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, use_mureadout=True + ) + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=7, use_mureadout=True + ) + self.check_ensemble_linear( + in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7, use_mureadout=True + ) + # for drop_rate=0.5, test if the output shape is correct def check_ensemble_fclayer( self,