|
| 1 | +import pytest |
| 2 | + |
| 3 | +from e3nn.nn.models.gate_points_2101 import Network |
| 4 | +from e3nn.o3 import Irreps |
| 5 | +from mtenn.conversion_utils.e3nn import E3NN |
| 6 | + |
| 7 | + |
| 8 | +@pytest.fixture |
| 9 | +def e3nn_kwargs(): |
| 10 | + return { |
| 11 | + "irreps_in": "5x0e+2x1o", |
| 12 | + "irreps_hidden": "10x0e+10x0o+1o+1e", |
| 13 | + "irreps_out": "0e", |
| 14 | + "irreps_node_attr": "0e", |
| 15 | + "irreps_edge_attr": Irreps.spherical_harmonics(2), |
| 16 | + "layers": 5, |
| 17 | + "max_radius": 10, |
| 18 | + "number_of_basis": 5, |
| 19 | + "radial_layers": 5, |
| 20 | + "radial_neurons": 32, |
| 21 | + "num_neighbors": 10, |
| 22 | + "num_nodes": 100, |
| 23 | + "reduce_output": True, |
| 24 | + } |
| 25 | + |
| 26 | + |
| 27 | +def test_build_e3nn_directly_kwargs(e3nn_kwargs): |
| 28 | + model = E3NN(**e3nn_kwargs) |
| 29 | + |
| 30 | + # Directly stored parameters |
| 31 | + assert model.irreps_in == Irreps(e3nn_kwargs["irreps_in"]) |
| 32 | + assert model.irreps_hidden == Irreps(e3nn_kwargs["irreps_hidden"]) |
| 33 | + assert model.irreps_out == Irreps(e3nn_kwargs["irreps_out"]) |
| 34 | + assert model.irreps_node_attr == Irreps(e3nn_kwargs["irreps_node_attr"]) |
| 35 | + assert model.irreps_edge_attr == Irreps(e3nn_kwargs["irreps_edge_attr"]) |
| 36 | + assert len(model.layers) == e3nn_kwargs["layers"] + 1 |
| 37 | + assert model.max_radius == e3nn_kwargs["max_radius"] |
| 38 | + assert model.number_of_basis == e3nn_kwargs["number_of_basis"] |
| 39 | + assert model.num_nodes == e3nn_kwargs["num_nodes"] |
| 40 | + assert model.reduce_output == e3nn_kwargs["reduce_output"] |
| 41 | + |
| 42 | + # Indirect ones |
| 43 | + conv = model.layers[-1] |
| 44 | + assert len(conv.fc.hs) - 2 == e3nn_kwargs["radial_layers"] |
| 45 | + assert conv.fc.hs[1] == e3nn_kwargs["radial_neurons"] |
| 46 | + assert conv.num_neighbors == e3nn_kwargs["num_neighbors"] |
| 47 | + |
| 48 | + |
| 49 | +def test_build_e3nn_from_e3nn_network(e3nn_kwargs): |
| 50 | + ref_model = Network(**e3nn_kwargs) |
| 51 | + model = E3NN(model=ref_model) |
| 52 | + |
| 53 | + # Directly stored parameters |
| 54 | + assert model.irreps_in == ref_model.irreps_in |
| 55 | + assert model.irreps_hidden == ref_model.irreps_hidden |
| 56 | + assert model.irreps_out == ref_model.irreps_out |
| 57 | + assert model.irreps_node_attr == ref_model.irreps_node_attr |
| 58 | + assert model.irreps_edge_attr == ref_model.irreps_edge_attr |
| 59 | + assert len(model.layers) == len(ref_model.layers) |
| 60 | + assert model.max_radius == ref_model.max_radius |
| 61 | + assert model.number_of_basis == ref_model.number_of_basis |
| 62 | + assert model.num_nodes == ref_model.num_nodes |
| 63 | + assert model.reduce_output == ref_model.reduce_output |
| 64 | + |
| 65 | + # Indirect ones |
| 66 | + ref_conv = ref_model.layers[-1] |
| 67 | + conv = model.layers[-1] |
| 68 | + assert len(conv.fc.hs) == len(ref_conv.fc.hs) |
| 69 | + assert conv.fc.hs[1] == ref_conv.fc.hs[1] |
| 70 | + assert conv.num_neighbors == ref_conv.num_neighbors |
0 commit comments