Commit 015373f

Tested the ensemble layers
1 parent 44f9eba commit 015373f

4 files changed: +274 −14 lines

graphium/nn/base_layers.py

Lines changed: 6 additions & 6 deletions
@@ -236,7 +236,7 @@ class MuReadoutGraphium(MuReadout):
 
     def __init__(self, in_features, *args, **kwargs):
         super().__init__(in_features, *args, **kwargs)
-        self.base_width = in_features
+        self._base_width = in_features
 
     @property
     def absolute_width(self):
@@ -442,7 +442,7 @@ def __init__(
         in_dim: int,
         hidden_dims: Union[Iterable[int], int],
         out_dim: int,
-        depth: int,
+        depth: Optional[int] = None,
         activation: Union[str, Callable] = "relu",
         last_activation: Union[str, Callable] = "none",
         dropout: float = 0.0,
@@ -530,12 +530,12 @@ def __init__(
 
         all_dims = [in_dim] + self.hidden_dims + [out_dim]
         fully_connected = []
-        if depth == 0:
+        if self.depth == 0:
            self.fully_connected = None
            return
        else:
-            for ii in range(depth):
-                if ii < (depth - 1):
+            for ii in range(self.depth):
+                if ii < (self.depth - 1):
                    # Define the parameters for all intermediate layers
                    this_activation = activation
                    this_normalization = normalization
@@ -551,7 +551,7 @@ def __init__(
            if constant_droppath_rate:
                this_drop_rate = droppath_rate
            else:
-                this_drop_rate = DropPath.get_stochastic_drop_rate(droppath_rate, ii, depth)
+                this_drop_rate = DropPath.get_stochastic_drop_rate(droppath_rate, ii, self.depth)
 
            # Add a fully-connected layer
            fully_connected.append(
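
Note on the first hunk: the constructor now writes to `self._base_width`, presumably because `base_width` is defined as a read-only property on these readout classes (the ensemble readout in the next file defines one explicitly), so assigning to `self.base_width` would collide with the property. A minimal sketch of the pattern; the `Readout` class here is illustrative only, not part of the codebase:

    # Hypothetical minimal example; only `_base_width` and the `base_width`
    # property mirror the diff above.
    class Readout:
        def __init__(self, in_features: int):
            # A plain `self.base_width = ...` would hit the read-only
            # property and raise AttributeError.
            self._base_width = in_features

        @property
        def base_width(self) -> float:
            return float(self._base_width)

The second change, `depth: Optional[int] = None`, lets `depth` be omitted when `hidden_dims` is given as a list; per the docstring in ensemble_layers.py below, `depth` must then be `None` or equal to `len(hidden_dims) + 1`, and the loop body reads `self.depth` so the value resolved from `hidden_dims` is used consistently.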

graphium/nn/ensemble_layers.py

Lines changed: 10 additions & 7 deletions
@@ -50,6 +50,7 @@ def reset_parameters(self):
         """
         Reset the parameters of the linear layer using the `init_fn`.
         """
+        set_base_shapes(self, None, rescale_params=False)  # Set the shapes of the tensors, useful for mup
         # Initialize weight using the provided initialization function
         self.init_fn(self.weight)
 
@@ -169,7 +170,7 @@ def __init__(
                 in_dim, out_dim, num_ensemble=num_ensemble, bias=bias, init_fn=init_fn
             )
         else:
-            self.linear = EnsembleMuReadoutGraphium(in_dim, out_dim, bias=bias)
+            self.linear = EnsembleMuReadoutGraphium(in_dim, out_dim, num_ensemble=num_ensemble, bias=bias)
 
         self.reset_parameters()
 
@@ -202,9 +203,10 @@ def __init__(
         readout_zero_init=False,
         output_mult=1.0,
     ):
+        self.in_dim = in_dim
         self.output_mult = output_mult
         self.readout_zero_init = readout_zero_init
-        self.base_width = in_dim
+        self._base_width = in_dim
         super().__init__(
             in_dim=in_dim,
             out_dim=out_dim,
@@ -254,7 +256,7 @@ def forward(self, x):
 
     @property
     def absolute_width(self):
-        return float(self.in_features)
+        return float(self.in_dim)
 
     @property
     def base_width(self):
@@ -279,8 +281,8 @@ def __init__(
         in_dim: int,
         hidden_dims: Union[Iterable[int], int],
         out_dim: int,
-        depth: int,
         num_ensemble: int,
+        depth: Optional[int] = None,
         reduction: Optional[Union[str, Callable]] = "none",
         activation: Union[str, Callable] = "relu",
         last_activation: Union[str, Callable] = "none",
@@ -304,13 +306,13 @@ def __init__(
                 or a list of dimensions in the hidden layers.
             out_dim:
                 Output dimension of the MLP.
+            num_ensemble:
+                Number of MLPs that run in parallel.
             depth:
                 If `hidden_dims` is an integer, `depth` is 1 + the number of
                 hidden layers to use.
                 If `hidden_dims` is a list, then
                 `depth` must be `None` or equal to `len(hidden_dims) + 1`
-            num_ensemble:
-                Number of MLPs that run in parallel.
             reduction:
                 Reduction to use at the end of the MLP. Choices:
@@ -358,7 +360,6 @@ def __init__(
             hidden_dims=hidden_dims,
             out_dim=out_dim,
             depth=depth,
-            num_ensemble=num_ensemble,
             activation=activation,
             last_activation=last_activation,
             dropout=dropout,
@@ -369,6 +370,8 @@ def __init__(
             last_layer_is_readout=last_layer_is_readout,
             droppath_rate=droppath_rate,
             constant_droppath_rate=constant_droppath_rate,
+            fc_layer=EnsembleFCLayer,
+            fc_layer_kwargs={"num_ensemble": num_ensemble},
         )
 
         self.reduction = self._parse_reduction(reduction)
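
The last two hunks change how the ensemble dimension reaches the layers: instead of forwarding `num_ensemble` through the parent constructor's signature, `EnsembleMLP` now passes `fc_layer=EnsembleFCLayer` and `fc_layer_kwargs={"num_ensemble": num_ensemble}` so the parent `MLP` builds ensemble layers itself. A hedged sketch of that factory-style plumbing; `MiniMLP` is illustrative only, not the real `MLP`:

    from typing import Any, Callable, Dict, List, Optional

    class MiniMLP:
        # The parent builds each layer from a configurable class, so a
        # subclass only supplies the class and its extra kwargs.
        def __init__(
            self,
            dims: List[int],
            fc_layer: Callable[..., Any],
            fc_layer_kwargs: Optional[Dict[str, Any]] = None,
        ):
            fc_layer_kwargs = fc_layer_kwargs or {}
            # Each layer gets the shared kwargs, e.g. {"num_ensemble": 3}
            self.fully_connected = [
                fc_layer(d_in, d_out, **fc_layer_kwargs)
                for d_in, d_out in zip(dims[:-1], dims[1:])
            ]

The earlier hunks fix the readout path to match: `EnsembleMuReadoutGraphium` now receives `num_ensemble`, stores `in_dim`, and reports `absolute_width` from `self.in_dim` rather than `self.in_features`, matching the ensemble layer's attribute names.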

graphium/utils/spaces.py

Lines changed: 2 additions & 1 deletion
@@ -4,6 +4,7 @@
 import torchmetrics.functional as TorchMetrics
 
 import graphium.nn.base_layers as BaseLayers
+import graphium.nn.ensemble_layers as EnsembleLayers
 from graphium.nn.architectures import FeedForwardNN, FeedForwardPyg, TaskHeads
 import graphium.utils.custom_lr as CustomLR
 import graphium.data.datamodule as Datamodules
@@ -28,7 +29,7 @@
 }
 
 ENSEMBLE_FC_LAYERS_DICT = {
-    "ens-fc": BaseLayers.EnsembleFCLayer,
+    "ens-fc": EnsembleLayers.EnsembleFCLayer,
 }
 
 PYG_LAYERS_DICT = {
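
`EnsembleFCLayer` lives in `graphium.nn.ensemble_layers`, not `base_layers`, so the old registry entry would fail to resolve; the new import fixes the lookup. A short usage sketch, assuming configs resolve layer classes through these dicts (the constructor call mirrors the tests below):

    from graphium.utils.spaces import ENSEMBLE_FC_LAYERS_DICT

    layer_cls = ENSEMBLE_FC_LAYERS_DICT["ens-fc"]  # -> EnsembleLayers.EnsembleFCLayer
    layer = layer_cls(11, 5, num_ensemble=3)       # same arguments as in the tests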

tests/test_ensemble_layers.py

Lines changed: 256 additions & 0 deletions
@@ -0,0 +1,256 @@
+"""
+Unit tests for the different layers of graphium/nn/ensemble_layers
+"""
+
+import numpy as np
+import torch
+from torch.nn import Linear
+import unittest as ut
+
+from graphium.nn.base_layers import FCLayer, MLP
+from graphium.nn.ensemble_layers import EnsembleLinear, EnsembleFCLayer, EnsembleMLP, EnsembleMuReadoutGraphium
+
+
+class test_Ensemble_Layers(ut.TestCase):
+
+    # Check that EnsembleLinear produces the right output shape and matches
+    # a set of independent `Linear` layers with the same weights
+    def check_ensemble_linear(self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim: int):
+
+        msg = f"Testing EnsembleLinear with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}"
+
+        # Create EnsembleLinear instance
+        ensemble_linear = EnsembleLinear(in_dim, out_dim, num_ensemble)
+
+        # Create equivalent separate Linear layers with synchronized weights and biases
+        linear_layers = [Linear(in_dim, out_dim) for _ in range(num_ensemble)]
+        for i, linear_layer in enumerate(linear_layers):
+            linear_layer.weight.data = ensemble_linear.weight.data[i]
+            if ensemble_linear.bias is not None:
+                linear_layer.bias.data = ensemble_linear.bias.data[i].squeeze()
+
+        # Test with a sample input
+        input_tensor = torch.randn(batch_size, in_dim)
+        ensemble_output = ensemble_linear(input_tensor)
+
+        # Check for the output shape
+        self.assertEqual(ensemble_output.shape, (num_ensemble, batch_size, out_dim), msg=msg)
+
+        # Make sure that the outputs of the individual layers are the same as the ensemble output
+        for i, linear_layer in enumerate(linear_layers):
+            individual_output = linear_layer(input_tensor)
+            individual_output = individual_output.detach().numpy()
+            ensemble_output_i = ensemble_output[i].detach().numpy()
+            np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg)
+
+        # Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimensions
+        if more_batch_dim:
+            out_shape = (more_batch_dim, num_ensemble, batch_size, out_dim)
+            input_tensor = torch.randn(more_batch_dim, num_ensemble, batch_size, in_dim)
+        else:
+            out_shape = (num_ensemble, batch_size, out_dim)
+            input_tensor = torch.randn(num_ensemble, batch_size, in_dim)
+        ensemble_output = ensemble_linear(input_tensor)
+
+        # Check for the output shape
+        self.assertEqual(ensemble_output.shape, out_shape, msg=msg)
+
+        # Make sure that the outputs of the individual layers are the same as the ensemble output
+        for i, linear_layer in enumerate(linear_layers):
+            if more_batch_dim:
+                individual_output = linear_layer(input_tensor[:, i])
+                ensemble_output_i = ensemble_output[:, i]
+            else:
+                individual_output = linear_layer(input_tensor[i])
+                ensemble_output_i = ensemble_output[i]
+            individual_output = individual_output.detach().numpy()
+            ensemble_output_i = ensemble_output_i.detach().numpy()
+            np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg)
+
+    def test_ensemble_linear(self):
+        # more_batch_dim=0
+        self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0)
+        self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=0)
+        self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=0)
+
+        # more_batch_dim=1
+        self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1)
+        self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=1)
+        self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=1)
+
+        # more_batch_dim=7
+        self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7)
+        self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=7)
+        self.check_ensemble_linear(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7)
+
+    # Check that EnsembleFCLayer matches independent FCLayer instances
+    def check_ensemble_fclayer(self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim: int, is_readout_layer=False):
+
+        msg = f"Testing EnsembleFCLayer with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}"
+
+        # Create EnsembleFCLayer instance
+        ensemble_fclayer = EnsembleFCLayer(in_dim, out_dim, num_ensemble, is_readout_layer=is_readout_layer)
+
+        # Create equivalent separate FCLayer layers with synchronized weights and biases
+        fc_layers = [FCLayer(in_dim, out_dim, is_readout_layer=is_readout_layer) for _ in range(num_ensemble)]
+        for i, fc_layer in enumerate(fc_layers):
+            fc_layer.linear.weight.data = ensemble_fclayer.linear.weight.data[i]
+            if ensemble_fclayer.bias is not None:
+                fc_layer.linear.bias.data = ensemble_fclayer.linear.bias.data[i].squeeze()
+
+        # Test with a sample input
+        input_tensor = torch.randn(batch_size, in_dim)
+        ensemble_output = ensemble_fclayer(input_tensor)
+
+        # Check for the output shape
+        self.assertEqual(ensemble_output.shape, (num_ensemble, batch_size, out_dim), msg=msg)
+
+        # Make sure that the outputs of the individual layers are the same as the ensemble output
+        for i, fc_layer in enumerate(fc_layers):
+            individual_output = fc_layer(input_tensor)
+            individual_output = individual_output.detach().numpy()
+            ensemble_output_i = ensemble_output[i].detach().numpy()
+            np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg)
+
+        # Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimensions
+        if more_batch_dim:
+            out_shape = (more_batch_dim, num_ensemble, batch_size, out_dim)
+            input_tensor = torch.randn(more_batch_dim, num_ensemble, batch_size, in_dim)
+        else:
+            out_shape = (num_ensemble, batch_size, out_dim)
+            input_tensor = torch.randn(num_ensemble, batch_size, in_dim)
+        ensemble_output = ensemble_fclayer(input_tensor)
+
+        # Check for the output shape
+        self.assertEqual(ensemble_output.shape, out_shape, msg=msg)
+
+        # Make sure that the outputs of the individual layers are the same as the ensemble output
+        for i, fc_layer in enumerate(fc_layers):
+            if more_batch_dim:
+                individual_output = fc_layer(input_tensor[:, i])
+                ensemble_output_i = ensemble_output[:, i]
+            else:
+                individual_output = fc_layer(input_tensor[i])
+                ensemble_output_i = ensemble_output[i]
+            individual_output = individual_output.detach().numpy()
+            ensemble_output_i = ensemble_output_i.detach().numpy()
+            np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg)
+
+    def test_ensemble_fclayer(self):
+        # more_batch_dim=0
+        self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0)
+        self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=0)
+        self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=0)
+
+        # more_batch_dim=1
+        self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1)
+        self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=1)
+        self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=1)
+
+        # more_batch_dim=7
+        self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7)
+        self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=7)
+        self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7)
+
+        # Test `is_readout_layer`
+        self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, is_readout_layer=True)
+        self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, is_readout_layer=True)
+        self.check_ensemble_fclayer(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, is_readout_layer=True)
+
+    # Check that EnsembleMLP matches independent MLP instances
+    def check_ensemble_mlp(self, in_dim: int, out_dim: int, num_ensemble: int, batch_size: int, more_batch_dim: int, last_layer_is_readout=False):
+
+        msg = f"Testing EnsembleMLP with in_dim={in_dim}, out_dim={out_dim}, num_ensemble={num_ensemble}, batch_size={batch_size}, more_batch_dim={more_batch_dim}"
+
+        # Create EnsembleMLP instance
+        hidden_dims = [17, 17, 17]
+        ensemble_mlp = EnsembleMLP(in_dim, hidden_dims, out_dim, num_ensemble, last_layer_is_readout=last_layer_is_readout)
+
+        # Create equivalent separate MLPs with synchronized weights and biases
+        mlps = [MLP(in_dim, hidden_dims, out_dim, last_layer_is_readout=last_layer_is_readout) for _ in range(num_ensemble)]
+        for i, mlp in enumerate(mlps):
+            for j, layer in enumerate(mlp.fully_connected):
+                layer.linear.weight.data = ensemble_mlp.fully_connected[j].linear.weight.data[i]
+                if layer.bias is not None:
+                    layer.linear.bias.data = ensemble_mlp.fully_connected[j].linear.bias.data[i].squeeze()
+
+        # Test with a sample input
+        input_tensor = torch.randn(batch_size, in_dim)
+        ensemble_output = ensemble_mlp(input_tensor)
+
+        # Check for the output shape
+        self.assertEqual(ensemble_output.shape, (num_ensemble, batch_size, out_dim), msg=msg)
+
+        # Make sure that the outputs of the individual MLPs are the same as the ensemble output
+        for i, mlp in enumerate(mlps):
+            individual_output = mlp(input_tensor)
+            individual_output = individual_output.detach().numpy()
+            ensemble_output_i = ensemble_output[i].detach().numpy()
+            np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg)
+
+        # Test with a sample input with the extra `num_ensemble` and `more_batch_dim` dimensions
+        if more_batch_dim:
+            out_shape = (more_batch_dim, num_ensemble, batch_size, out_dim)
+            input_tensor = torch.randn(more_batch_dim, num_ensemble, batch_size, in_dim)
+        else:
+            out_shape = (num_ensemble, batch_size, out_dim)
+            input_tensor = torch.randn(num_ensemble, batch_size, in_dim)
+        ensemble_output = ensemble_mlp(input_tensor)
+
+        # Check for the output shape
+        self.assertEqual(ensemble_output.shape, out_shape, msg=msg)
+
+        # Make sure that the outputs of the individual MLPs are the same as the ensemble output
+        for i, mlp in enumerate(mlps):
+            if more_batch_dim:
+                individual_output = mlp(input_tensor[:, i])
+                ensemble_output_i = ensemble_output[:, i]
+            else:
+                individual_output = mlp(input_tensor[i])
+                ensemble_output_i = ensemble_output[i]
+            individual_output = individual_output.detach().numpy()
+            ensemble_output_i = ensemble_output_i.detach().numpy()
+            np.testing.assert_allclose(ensemble_output_i, individual_output, atol=1e-5, err_msg=msg)
+
+    def test_ensemble_mlp(self):
+        # more_batch_dim=0
+        self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0)
+        self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=0)
+        self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=0)
+
+        # more_batch_dim=1
+        self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1)
+        self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=1)
+        self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=1)
+
+        # more_batch_dim=7
+        self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7)
+        self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=1, more_batch_dim=7)
+        self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=1, batch_size=13, more_batch_dim=7)
+
+        # Test `last_layer_is_readout`
+        self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=0, last_layer_is_readout=True)
+        self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=1, last_layer_is_readout=True)
+        self.check_ensemble_mlp(in_dim=11, out_dim=5, num_ensemble=3, batch_size=13, more_batch_dim=7, last_layer_is_readout=True)
+
+
+if __name__ == "__main__":
+    ut.main()
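
Taken together, the tests pin down the broadcasting contract shared by all three ensemble layers. A compact restatement of the shapes the assertions check (a sketch derived from the tests, not separate API documentation):

    import torch
    from graphium.nn.ensemble_layers import EnsembleLinear

    ens = EnsembleLinear(11, 5, 3)  # in_dim=11, out_dim=5, num_ensemble=3

    # 2D input: one batch shared by all ensemble members
    assert ens(torch.randn(13, 11)).shape == (3, 13, 5)

    # 3D input: one batch per ensemble member
    assert ens(torch.randn(3, 13, 11)).shape == (3, 13, 5)

    # 4D input: extra leading batch dimension (`more_batch_dim`)
    assert ens(torch.randn(7, 3, 13, 11)).shape == (7, 3, 13, 5)

Since the file calls `ut.main()`, the suite can be run directly with `python tests/test_ensemble_layers.py` or through pytest.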
