|
4 | 4 | from atomate2.forcefields.utils import ase_calculator
|
5 | 5 |
|
6 | 6 |
|
7 |
| -@pytest.mark.parametrize(("force_field"), [mlff.value for mlff in MLFF]) |
8 |
| -def test_mlff(force_field: str): |
9 |
| - mlff = MLFF(force_field) |
| 7 | +@pytest.mark.parametrize("mlff", MLFF) |
| 8 | +def test_mlff(mlff: MLFF): |
10 | 9 | assert mlff == MLFF(str(mlff)) == MLFF(str(mlff).split(".")[-1])
|
11 | 10 |
|
12 | 11 |
|
13 |
| -@pytest.mark.parametrize(("force_field"), ["CHGNet", "MACE"]) |
14 |
| -def test_ext_load(force_field: str): |
| 12 | +@pytest.mark.parametrize("mlff", ["CHGNet", "MACE", MLFF.MatterSim, MLFF.SevenNet]) |
| 13 | +def test_ext_load(mlff: str): |
15 | 14 | decode_dict = {
|
16 | 15 | "CHGNet": {"@module": "chgnet.model.dynamics", "@callable": "CHGNetCalculator"},
|
17 | 16 | "MACE": {"@module": "mace.calculators", "@callable": "mace_mp"},
|
18 |
| - }[force_field] |
| 17 | + MLFF.MatterSim: { |
| 18 | + "@module": "mattersim.forcefield", |
| 19 | + "@callable": "MatterSimCalculator", |
| 20 | + }, |
| 21 | + MLFF.SevenNet: { |
| 22 | + "@module": "sevenn.sevennet_calculator", |
| 23 | + "@callable": "SevenNetCalculator", |
| 24 | + }, |
| 25 | + }[mlff] |
19 | 26 | calc_from_decode = ase_calculator(decode_dict)
|
20 |
| - calc_from_preset = ase_calculator(str(MLFF(force_field))) |
| 27 | + calc_from_preset = ase_calculator(str(MLFF(mlff))) |
21 | 28 | assert type(calc_from_decode) is type(calc_from_preset)
|
22 | 29 | assert calc_from_decode.name == calc_from_preset.name
|
23 | 30 | assert calc_from_decode.parameters == calc_from_preset.parameters == {}
|
|
0 commit comments