Skip to content

Commit e07e855

Browse files
committed
add MatterSim and SevenNet to test_ext_load
1 parent e5945b4 commit e07e855

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

tests/forcefields/test_utils.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,27 @@
44
from atomate2.forcefields.utils import ase_calculator
55

66

7-
@pytest.mark.parametrize(("force_field"), [mlff.value for mlff in MLFF])
8-
def test_mlff(force_field: str):
9-
mlff = MLFF(force_field)
7+
@pytest.mark.parametrize("mlff", MLFF)
8+
def test_mlff(mlff: MLFF):
109
assert mlff == MLFF(str(mlff)) == MLFF(str(mlff).split(".")[-1])
1110

1211

13-
@pytest.mark.parametrize(("force_field"), ["CHGNet", "MACE"])
14-
def test_ext_load(force_field: str):
12+
@pytest.mark.parametrize("mlff", ["CHGNet", "MACE", MLFF.MatterSim, MLFF.SevenNet])
13+
def test_ext_load(mlff: str):
1514
decode_dict = {
1615
"CHGNet": {"@module": "chgnet.model.dynamics", "@callable": "CHGNetCalculator"},
1716
"MACE": {"@module": "mace.calculators", "@callable": "mace_mp"},
18-
}[force_field]
17+
MLFF.MatterSim: {
18+
"@module": "mattersim.forcefield",
19+
"@callable": "MatterSimCalculator",
20+
},
21+
MLFF.SevenNet: {
22+
"@module": "sevenn.sevennet_calculator",
23+
"@callable": "SevenNetCalculator",
24+
},
25+
}[mlff]
1926
calc_from_decode = ase_calculator(decode_dict)
20-
calc_from_preset = ase_calculator(str(MLFF(force_field)))
27+
calc_from_preset = ase_calculator(str(MLFF(mlff)))
2128
assert type(calc_from_decode) is type(calc_from_preset)
2229
assert calc_from_decode.name == calc_from_preset.name
2330
assert calc_from_decode.parameters == calc_from_preset.parameters == {}

0 commit comments

Comments
 (0)