Skip to content

Commit

Permalink
Added tests for mu_readout
Browse files Browse the repository at this point in the history
  • Loading branch information
DomInvivo committed Dec 13, 2023
1 parent e0f841a commit dd52ca5
Showing 1 changed file with 53 additions and 6 deletions.
59 changes: 53 additions & 6 deletions tests/test_ensemble_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torch.nn import Linear
import unittest as ut

from graphium.nn.base_layers import FCLayer, MLP
from graphium.nn.base_layers import FCLayer, MLP, MuReadoutGraphium
from graphium.nn.ensemble_layers import (
EnsembleLinear,
EnsembleFCLayer,
Expand All @@ -19,15 +19,27 @@
class test_Ensemble_Layers(ut.TestCase):
# for drop_rate=0.5, test if the output shape is correct
def check_ensemble_linear(
self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim: int
self,
in_dim: int,
out_dim: int,
num_ensemble: int,
batch_size: int,
more_batch_dim: int,
use_mureadout=False,
):
msg = f"Testing EnsembleLinear with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}"

# Create EnsembleLinear instance
ensemble_linear = EnsembleLinear(in_dim, out_dim, num_ensemble)
if use_mureadout:
# Create EnsembleMuReadoutGraphium instance
ensemble_linear = EnsembleMuReadoutGraphium(in_dim, out_dim, num_ensemble)
# Create equivalent separate Linear layers with synchronized weights and biases
linear_layers = [MuReadoutGraphium(in_dim, out_dim) for _ in range(num_ensemble)]
else:
# Create EnsembleLinear instance
ensemble_linear = EnsembleLinear(in_dim, out_dim, num_ensemble)
# Create equivalent separate Linear layers with synchronized weights and biases
linear_layers = [Linear(in_dim, out_dim) for _ in range(num_ensemble)]

# Create equivalent separate Linear layers with synchronized weights and biases
linear_layers = [Linear(in_dim, out_dim) for _ in range(num_ensemble)]
for i, linear_layer in enumerate(linear_layers):
linear_layer.weight.data = ensemble_linear.weight.data[i]
if ensemble_linear.bias is not None:
Expand Down Expand Up @@ -87,6 +99,41 @@ def test_ensemble_linear(self):
self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=7)
self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7)

def test_ensemble_mureadout_graphium(self):
    """Exercise `check_ensemble_linear` with `use_mureadout=True`.

    Sweeps the same parameter grid as `test_ensemble_linear`:
    every `more_batch_dim` in (0, 1, 7) crossed with three
    (num_ensemble, batch_size) configurations, so the MuReadout-based
    ensemble layer is validated for multi-ensemble, single-sample,
    and single-ensemble cases alike.
    """
    batch_configs = ((3, 13), (3, 1), (1, 13))  # (num_ensemble, batch_size)
    for more_batch_dim in (0, 1, 7):
        for num_ensemble, batch_size in batch_configs:
            self.check_ensemble_linear(
                in_dim=11,
                out_dim=5,
                num_ensemble=num_ensemble,
                batch_size=batch_size,
                more_batch_dim=more_batch_dim,
                use_mureadout=True,
            )

# for drop_rate=0.5, test if the output shape is correct
def check_ensemble_fclayer(
self,
Expand Down

0 comments on commit dd52ca5

Please sign in to comment.