Skip to content

Commit 1b85dd4

Browse files
add missing metadata in ForceFieldTaskDoc, MACE-MPA-0, test for it
1 parent d803b9a commit 1b85dd4

File tree

5 files changed

+28
-6
lines changed

5 files changed

+28
-6
lines changed

src/atomate2/ase/schemas.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -247,10 +247,10 @@ def from_ase_task_doc(
247247
Additional keyword args passed to :obj:`.AseStructureTaskDoc()`.
248248
"""
249249
task_document_kwargs.update(
250-
{k: getattr(ase_task_doc, k) for k in _task_doc_translation_keys}
250+
{k: getattr(ase_task_doc, k) for k in _task_doc_translation_keys},
251+
structure = ase_task_doc.mol_or_struct,
251252
)
252-
task_document_kwargs["structure"] = ase_task_doc.mol_or_struct
253-
return cls(**task_document_kwargs)
253+
return cls.from_structure(meta_structure = ase_task_doc.mol_or_struct, **task_document_kwargs)
254254

255255

256256
class AseMoleculeTaskDoc(MoleculeMetadata):

src/atomate2/forcefields/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
class MLFF(Enum): # TODO inherit from StrEnum when 3.11+
1313
"""Names of ML force fields."""
1414

15-
MACE = "MACE"
15+
MACE = "MACE" # This is MACE-MP-0-medium
16+
MACE_MPA_0 = "MACE_MPA_0"
1617
GAP = "GAP"
1718
M3GNet = "M3GNet"
1819
CHGNet = "CHGNet"

src/atomate2/forcefields/jobs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
MLFF.NEP: {"model_filename": "nep.txt"},
3636
MLFF.GAP: {"args_str": "IP GAP", "param_filename": "gap.xml"},
3737
MLFF.MACE: {"model": "medium"},
38+
MLFF.MACE_MPA_0 : {"model": "medium-mpa-0",},
3839
}
3940

4041

src/atomate2/forcefields/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ def ase_calculator(calculator_meta: str | dict, **kwargs: Any) -> Calculator | N
4343
calculator = None
4444

4545
if isinstance(calculator_meta, str | MLFF) and calculator_meta in map(str, MLFF):
46-
calculator_name = MLFF(calculator_meta.split("MLFF.")[-1])
46+
47+
calculator_name = MLFF[calculator_meta.split("MLFF.")[-1]]
4748

4849
if calculator_name == MLFF.CHGNet:
4950
from chgnet.model.dynamics import CHGNetCalculator
@@ -58,7 +59,7 @@ def ase_calculator(calculator_meta: str | dict, **kwargs: Any) -> Calculator | N
5859
potential = matgl.load_model(path)
5960
calculator = PESCalculator(potential, **kwargs)
6061

61-
elif calculator_name == MLFF.MACE:
62+
elif calculator_name in {MLFF.MACE, MLFF.MACE_MPA_0}:
6263
from mace.calculators import MACECalculator, mace_mp
6364

6465
model = kwargs.get("model")

tests/forcefields/test_jobs.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,25 @@ def test_mace_relax_maker(
318318
assert output1.output.energy == approx(-0.06772976, rel=1e-4)
319319
assert output1.output.n_steps == 7
320320

321+
def test_mace_mpa_0_relax_maker(
322+
si_structure: Structure,
323+
):
324+
job = ForceFieldRelaxMaker(
325+
force_field_name="MACE_MPA_0",
326+
steps=25,
327+
relax_kwargs={"fmax": 0.005},
328+
).make(si_structure)
329+
# run the flow or job and ensure that it finished running successfully
330+
responses = run_locally(job, ensure_success=True)
331+
332+
# validating the outputs of the job
333+
output = responses[job.uuid][1].output
334+
335+
assert output.ase_calculator_name == "MLFF.MACE_MPA_0"
336+
assert output.output.energy == pytest.approx(-10.829493522644043)
337+
assert output.output.structure.volume == pytest.approx(40.87471552602735)
338+
assert len(output.output.ionic_steps) == 4
339+
assert output.structure.volume == output.output.structure.volume
321340

322341
def test_gap_static_maker(si_structure: Structure, test_dir):
323342
importorskip("quippy")

0 commit comments

Comments
 (0)