Skip to content

Commit 6fcec40

Browse files
authored
Merge pull request #49 from choderalab/fix-e3nn-ref-model
Fix e3nn model building
2 parents 3b08d3e + 0d610d3 commit 6fcec40

File tree

2 files changed

+90
-17
lines changed

2 files changed

+90
-17
lines changed

mtenn/conversion_utils/e3nn.py

+20-17
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,26 @@ def __init__(self, *args, model=None, **kwargs):
1717
super(E3NN, self).__init__(*args, **kwargs)
1818
self.model_parameters = kwargs
1919
else:
20-
# this will need changing to include model features of e3nn
21-
atomref = model.atomref.weight.detach().clone()
22-
model_params = (
23-
model.hidden_channels,
24-
model.num_filters,
25-
model.num_interactions,
26-
model.num_gaussians,
27-
model.cutoff,
28-
model.max_num_neighbors,
29-
model.readout,
30-
model.dipole,
31-
model.mean,
32-
model.std,
33-
atomref,
34-
)
35-
super(E3NN, self).__init__(*model_params)
36-
self.model_parameters = model_params
20+
model_kwargs = {
21+
"irreps_in": model.irreps_in,
22+
"irreps_hidden": model.irreps_hidden,
23+
"irreps_out": model.irreps_out,
24+
"irreps_node_attr": model.irreps_node_attr,
25+
"irreps_edge_attr": model.irreps_edge_attr,
26+
"layers": len(model.layers) - 1,
27+
"max_radius": model.max_radius,
28+
"number_of_basis": model.number_of_basis,
29+
"num_nodes": model.num_nodes,
30+
"reduce_output": model.reduce_output,
31+
}
32+
# These need a bit of work to get
33+
# Use last layer bc guaranteed to be present and is just a Convolution
34+
conv = model.layers[-1]
35+
model_kwargs["radial_layers"] = len(conv.fc.hs) - 2
36+
model_kwargs["radial_neurons"] = conv.fc.hs[1]
37+
model_kwargs["num_neighbors"] = conv.num_neighbors
38+
super(E3NN, self).__init__(**model_kwargs)
39+
self.model_parameters = model_kwargs
3740
self.load_state_dict(model.state_dict())
3841

3942
def forward(self, data):

mtenn/tests/test_e3nn.py

+70
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import pytest
2+
3+
from e3nn.nn.models.gate_points_2101 import Network
4+
from e3nn.o3 import Irreps
5+
from mtenn.conversion_utils.e3nn import E3NN
6+
7+
8+
@pytest.fixture
9+
def e3nn_kwargs():
10+
return {
11+
"irreps_in": "5x0e+2x1o",
12+
"irreps_hidden": "10x0e+10x0o+1o+1e",
13+
"irreps_out": "0e",
14+
"irreps_node_attr": "0e",
15+
"irreps_edge_attr": Irreps.spherical_harmonics(2),
16+
"layers": 5,
17+
"max_radius": 10,
18+
"number_of_basis": 5,
19+
"radial_layers": 5,
20+
"radial_neurons": 32,
21+
"num_neighbors": 10,
22+
"num_nodes": 100,
23+
"reduce_output": True,
24+
}
25+
26+
27+
def test_build_e3nn_directly_kwargs(e3nn_kwargs):
28+
model = E3NN(**e3nn_kwargs)
29+
30+
# Directly stored parameters
31+
assert model.irreps_in == Irreps(e3nn_kwargs["irreps_in"])
32+
assert model.irreps_hidden == Irreps(e3nn_kwargs["irreps_hidden"])
33+
assert model.irreps_out == Irreps(e3nn_kwargs["irreps_out"])
34+
assert model.irreps_node_attr == Irreps(e3nn_kwargs["irreps_node_attr"])
35+
assert model.irreps_edge_attr == Irreps(e3nn_kwargs["irreps_edge_attr"])
36+
assert len(model.layers) == e3nn_kwargs["layers"] + 1
37+
assert model.max_radius == e3nn_kwargs["max_radius"]
38+
assert model.number_of_basis == e3nn_kwargs["number_of_basis"]
39+
assert model.num_nodes == e3nn_kwargs["num_nodes"]
40+
assert model.reduce_output == e3nn_kwargs["reduce_output"]
41+
42+
# Indirect ones
43+
conv = model.layers[-1]
44+
assert len(conv.fc.hs) - 2 == e3nn_kwargs["radial_layers"]
45+
assert conv.fc.hs[1] == e3nn_kwargs["radial_neurons"]
46+
assert conv.num_neighbors == e3nn_kwargs["num_neighbors"]
47+
48+
49+
def test_build_e3nn_from_e3nn_network(e3nn_kwargs):
50+
ref_model = Network(**e3nn_kwargs)
51+
model = E3NN(model=ref_model)
52+
53+
# Directly stored parameters
54+
assert model.irreps_in == ref_model.irreps_in
55+
assert model.irreps_hidden == ref_model.irreps_hidden
56+
assert model.irreps_out == ref_model.irreps_out
57+
assert model.irreps_node_attr == ref_model.irreps_node_attr
58+
assert model.irreps_edge_attr == ref_model.irreps_edge_attr
59+
assert len(model.layers) == len(ref_model.layers)
60+
assert model.max_radius == ref_model.max_radius
61+
assert model.number_of_basis == ref_model.number_of_basis
62+
assert model.num_nodes == ref_model.num_nodes
63+
assert model.reduce_output == ref_model.reduce_output
64+
65+
# Indirect ones
66+
ref_conv = ref_model.layers[-1]
67+
conv = model.layers[-1]
68+
assert len(conv.fc.hs) == len(ref_conv.fc.hs)
69+
assert conv.fc.hs[1] == ref_conv.fc.hs[1]
70+
assert conv.num_neighbors == ref_conv.num_neighbors

0 commit comments

Comments
 (0)