Skip to content

Commit ccdf013

Browse files
Merge pull request #1127 from yaoyi92/mol_for_forcefield
Molecule for forcefield
2 parents c40a394 + 2a47376 commit ccdf013

File tree

9 files changed

+122
-47
lines changed

9 files changed

+122
-47
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,4 +79,5 @@ docs/reference/atomate2.*
7979
.ipynb_checkpoints
8080
.aider*
8181

82-
tests/test_data/forcefields/deepmd_graph.pb
82+
# deepmd-kit files
83+
**/*.pb

src/atomate2/forcefields/jobs.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
from collections.abc import Callable
1818
from pathlib import Path
1919

20-
from pymatgen.core.structure import Structure
20+
from pymatgen.core.structure import Molecule, Structure
21+
22+
from atomate2.forcefields.schemas import ForceFieldMoleculeTaskDocument
2123

2224
logger = logging.getLogger(__name__)
2325

@@ -29,7 +31,9 @@ def forcefield_job(method: Callable) -> job:
2931
This is a thin wrapper around :obj:`~jobflow.core.job.Job` that configures common
3032
settings for all forcefield jobs. For example, it ensures that large data objects
3133
(currently only trajectories) are all stored in the atomate2 data store.
32-
It also configures the output schema to be a ForceFieldTaskDocument :obj:`.TaskDoc`.
34+
It also configures the output schema to be a
35+
ForceFieldTaskDocument :obj:`.TaskDoc` or
36+
ForceFieldMoleculeTaskDocument :obj:`.TaskDoc`.
3337
3438
Any makers that return forcefield jobs (not flows) should decorate the
3539
``make`` method with @forcefield_job. For example:
@@ -53,9 +57,7 @@ def make(structure):
5357
callable
5458
A decorated version of the make function that will generate forcefield jobs.
5559
"""
56-
return job(
57-
method, data=_FORCEFIELD_DATA_OBJECTS, output_schema=ForceFieldTaskDocument
58-
)
60+
return job(method, data=_FORCEFIELD_DATA_OBJECTS)
5961

6062

6163
@dataclass
@@ -99,7 +101,8 @@ class ForceFieldRelaxMaker(ForceFieldMixin, AseRelaxMaker):
99101
tags : list[str] or None
100102
A list of tags for the task.
101103
task_document_kwargs : dict (deprecated)
102-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
104+
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()` or
105+
:obj:`.ForceFieldMoleculeTaskDocument`.
103106
"""
104107

105108
name: str = "Force field relax"
@@ -115,15 +118,15 @@ class ForceFieldRelaxMaker(ForceFieldMixin, AseRelaxMaker):
115118

116119
@forcefield_job
117120
def make(
118-
self, structure: Structure, prev_dir: str | Path | None = None
119-
) -> ForceFieldTaskDocument:
121+
self, structure: Molecule | Structure, prev_dir: str | Path | None = None
122+
) -> ForceFieldTaskDocument | ForceFieldMoleculeTaskDocument:
120123
"""
121124
Perform a relaxation of a structure using a force field.
122125
123126
Parameters
124127
----------
125-
structure: .Structure
126-
pymatgen structure.
128+
structure: .Structure or Molecule
129+
pymatgen structure or molecule.
127130
prev_dir : str or Path or None
128131
A previous calculation directory to copy output files from. Unused, just
129132
added to match the method signature of other makers.
@@ -172,7 +175,8 @@ class ForceFieldStaticMaker(ForceFieldRelaxMaker):
172175
calculator_kwargs : dict
173176
Keyword arguments that will get passed to the ASE calculator.
174177
task_document_kwargs : dict (deprecated)
175-
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()`.
178+
Additional keyword args passed to :obj:`.ForceFieldTaskDocument()` or
179+
:obj:`.ForceFieldMoleculeTaskDocument`.
176180
"""
177181

178182
name: str = "Force field static"

src/atomate2/forcefields/md.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
if TYPE_CHECKING:
1616
from pathlib import Path
1717

18-
from pymatgen.core.structure import Structure
18+
from pymatgen.core.structure import Molecule, Structure
19+
20+
from atomate2.forcefields.schemas import ForceFieldMoleculeTaskDocument
1921

2022

2123
@dataclass
@@ -104,19 +106,18 @@ class ForceFieldMDMaker(ForceFieldMixin, AseMDMaker):
104106

105107
@job(
106108
data=[*_FORCEFIELD_DATA_OBJECTS, "ionic_steps"],
107-
output_schema=ForceFieldTaskDocument,
108109
)
109110
def make(
110111
self,
111-
structure: Structure,
112+
structure: Molecule | Structure,
112113
prev_dir: str | Path | None = None,
113-
) -> ForceFieldTaskDocument:
114+
) -> ForceFieldTaskDocument | ForceFieldMoleculeTaskDocument:
114115
"""
115116
Perform MD on a structure using forcefields and jobflow.
116117
117118
Parameters
118119
----------
119-
structure: .Structure
120+
structure: .Structure or Molecule
120121
pymatgen structure or molecule.
121122
prev_dir : str or Path or None
122123
A previous calculation directory to copy output files from. Unused, just

src/atomate2/forcefields/schemas.py

Lines changed: 58 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,28 @@
22

33
from __future__ import annotations
44

5-
from typing import Any
5+
from typing import TYPE_CHECKING, Any
66

77
from emmet.core.types.enums import StoreTrajectoryOption
8-
from pydantic import Field
9-
10-
from atomate2.ase.schemas import AseObject, AseResult, AseStructureTaskDoc, AseTaskDoc
8+
from pydantic import BaseModel, Field
9+
from pymatgen.core import Molecule
10+
11+
from atomate2.ase.schemas import (
12+
AseMoleculeTaskDoc,
13+
AseObject,
14+
AseResult,
15+
AseStructureTaskDoc,
16+
AseTaskDoc,
17+
_task_doc_translation_keys,
18+
)
1119
from atomate2.forcefields import MLFF
1220

21+
if TYPE_CHECKING:
22+
from typing_extensions import Self
23+
1324

14-
class ForceFieldTaskDocument(AseStructureTaskDoc):
15-
"""Document containing information on structure manipulation using a force field."""
25+
class ForceFieldMeta(BaseModel):
26+
"""Add metadata to forcefield output documents."""
1627

1728
forcefield_name: str | None = Field(
1829
None,
@@ -42,6 +53,40 @@ class ForceFieldTaskDocument(AseStructureTaskDoc):
4253
),
4354
)
4455

56+
@property
57+
def forcefield_objects(self) -> dict[AseObject, Any] | None:
58+
"""Alias `objects` attr for backwards compatibility."""
59+
return self.objects
60+
61+
62+
class ForceFieldMoleculeTaskDocument(AseMoleculeTaskDoc, ForceFieldMeta):
63+
"""Document containing information on molecule manipulation using a force field."""
64+
65+
@classmethod
66+
def from_ase_task_doc(
67+
cls, ase_task_doc: AseTaskDoc, **task_document_kwargs
68+
) -> Self:
69+
"""Create a ForceFieldMoleculeTaskDocument from an AseTaskDoc.
70+
71+
Parameters
72+
----------
73+
ase_task_doc : AseTaskDoc
74+
Task doc for the calculation
75+
task_document_kwargs : dict
76+
Additional keyword args passed to :obj:`.AseMoleculeTaskDoc()`.
77+
"""
78+
task_document_kwargs.update(
79+
{k: getattr(ase_task_doc, k) for k in _task_doc_translation_keys},
80+
structure=ase_task_doc.mol_or_struct,
81+
)
82+
return cls.from_molecule(
83+
meta_molecule=ase_task_doc.mol_or_struct, **task_document_kwargs
84+
)
85+
86+
87+
class ForceFieldTaskDocument(AseStructureTaskDoc, ForceFieldMeta):
88+
"""Document containing information on atomistic manipulation using a force field."""
89+
4590
@classmethod
4691
def from_ase_compatible_result(
4792
cls,
@@ -62,8 +107,8 @@ def from_ase_compatible_result(
62107
store_trajectory: StoreTrajectoryOption = StoreTrajectoryOption.NO,
63108
tags: list[str] | None = None,
64109
**task_document_kwargs,
65-
) -> ForceFieldTaskDocument:
66-
"""Create an AseTaskDoc for a task that has ASE-compatible outputs.
110+
) -> Self | ForceFieldMoleculeTaskDocument:
111+
"""Create forcefield output for a task that has ASE-compatible outputs.
67112
68113
Parameters
69114
----------
@@ -131,9 +176,8 @@ def from_ase_compatible_result(
131176

132177
ff_kwargs["forcefield_version"] = importlib.metadata.version(pkg_name)
133178

134-
return cls.from_ase_task_doc(ase_task_doc, **ff_kwargs)
135-
136-
@property
137-
def forcefield_objects(self) -> dict[AseObject, Any] | None:
138-
"""Alias `objects` attr for backwards compatibility."""
139-
return self.objects
179+
return (
180+
ForceFieldMoleculeTaskDocument
181+
if isinstance(result.final_mol_or_struct, Molecule)
182+
else cls
183+
).from_ase_task_doc(ase_task_doc, **ff_kwargs)

tests/forcefields/conftest.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
from __future__ import annotations
22

33
import hashlib
4+
import tempfile
45
import urllib.request
6+
from pathlib import Path
57
from typing import TYPE_CHECKING
68

79
import pytest
810
import torch
911
from emmet.core.utils import get_hash_blocked
1012

1113
if TYPE_CHECKING:
12-
from pathlib import Path
1314
from typing import Any
1415

1516

@@ -21,15 +22,17 @@ def pytest_runtest_setup(item: Any) -> None:
2122

2223

2324
@pytest.fixture(scope="session", autouse=True)
24-
def download_deepmd_pretrained_model(test_dir: Path) -> None:
25+
def get_deepmd_pretrained_model_path(test_dir: Path) -> Path:
2526
# Download DeepMD pretrained model from GitHub
2627
file_url = "https://raw.github.com/sliutheorygroup/UniPero/main/model/graph.pb"
27-
local_path = test_dir / "forcefields" / "deepmd_graph.pb"
28+
local_path = tempfile.NamedTemporaryFile(suffix=".pb") # noqa : SIM115
2829
ref_md5 = "2814ae7f2eb1c605dd78f2964187de40"
29-
_, http_message = urllib.request.urlretrieve(file_url, local_path) # noqa: S310
30+
_, http_message = urllib.request.urlretrieve(file_url, local_path.name) # noqa: S310
3031
if "Content-Type: text/html" in http_message:
3132
raise RuntimeError(f"Failed to download from: {file_url}")
3233

3334
# Check MD5 to ensure file integrity
34-
if (file_md5 := get_hash_blocked(local_path, hasher=hashlib.md5())) != ref_md5:
35+
if (file_md5 := get_hash_blocked(local_path.name, hasher=hashlib.md5())) != ref_md5:
3536
raise RuntimeError(f"MD5 mismatch: {file_md5} != {ref_md5}")
37+
yield Path(local_path.name)
38+
local_path.close()

tests/forcefields/test_jobs.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,15 @@
44
import numpy as np
55
import pytest
66
from jobflow import run_locally
7-
from pymatgen.core import Structure
7+
from pymatgen.core import Molecule, Structure
88
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
99
from pytest import approx, importorskip
1010

1111
from atomate2.forcefields.jobs import ForceFieldRelaxMaker, ForceFieldStaticMaker
12-
from atomate2.forcefields.schemas import ForceFieldTaskDocument
12+
from atomate2.forcefields.schemas import (
13+
ForceFieldMoleculeTaskDocument,
14+
ForceFieldTaskDocument,
15+
)
1316

1417

1518
def test_maker_initialization():
@@ -299,9 +302,7 @@ def test_mace_relax_maker(
299302
assert output1.output.n_steps == 7
300303

301304

302-
def test_mace_mpa_0_relax_maker(
303-
si_structure: Structure,
304-
):
305+
def test_mace_mpa_0_relax_maker(si_structure: Structure, test_dir: Path, tmp_dir):
305306
job = ForceFieldRelaxMaker(
306307
force_field_name="MACE_MPA_0",
307308
steps=25,
@@ -313,12 +314,29 @@ def test_mace_mpa_0_relax_maker(
313314
# validating the outputs of the job
314315
output = responses[job.uuid][1].output
315316

317+
water_molecule = Molecule.from_file(test_dir / "molecules" / "water.xyz.gz")
318+
job_mol = ForceFieldRelaxMaker(
319+
force_field_name="MACE_MPA_0",
320+
steps=25,
321+
relax_kwargs={"fmax": 0.005},
322+
).make(water_molecule)
323+
# run the flow or job and ensure that it finished running successfully
324+
responses_mol = run_locally(job_mol, ensure_success=True)
325+
326+
# validating the outputs of the job
327+
output_mol = responses_mol[job_mol.uuid][1].output
328+
assert isinstance(output_mol, ForceFieldMoleculeTaskDocument)
329+
316330
assert output.ase_calculator_name == "MLFF.MACE_MPA_0"
317331
assert output.output.energy == pytest.approx(-10.829493522644043)
318332
assert output.output.structure.volume == pytest.approx(40.87471552602735)
319333
assert len(output.output.ionic_steps) == 4
320334
assert output.structure.volume == output.output.structure.volume
321335

336+
assert output_mol.ase_calculator_name == "MLFF.MACE_MPA_0"
337+
assert output_mol.output.energy == pytest.approx(-13.786081314086914)
338+
assert len(output_mol.output.ionic_steps) == 20
339+
322340

323341
def test_gap_static_maker(si_structure: Structure, test_dir):
324342
importorskip("quippy")
@@ -522,14 +540,16 @@ def test_nequip_relax_maker(
522540
assert final_spg_num == 99
523541

524542

525-
def test_deepmd_static_maker(sr_ti_o3_structure: Structure, test_dir: Path):
543+
def test_deepmd_static_maker(
544+
sr_ti_o3_structure: Structure, test_dir: Path, get_deepmd_pretrained_model_path
545+
):
526546
importorskip("deepmd")
527547

528548
# generate job
529549
job = ForceFieldStaticMaker(
530550
force_field_name="DeepMD",
531551
ionic_step_data=("structure", "energy"),
532-
calculator_kwargs={"model": test_dir / "forcefields" / "deepmd_graph.pb"},
552+
calculator_kwargs={"model": get_deepmd_pretrained_model_path},
533553
).make(sr_ti_o3_structure)
534554

535555
# run the flow or job and ensure that it finished running successfully
@@ -552,6 +572,7 @@ def test_deepmd_relax_maker(
552572
test_dir: Path,
553573
relax_cell: bool,
554574
fix_symmetry: bool,
575+
get_deepmd_pretrained_model_path: Path,
555576
):
556577
importorskip("deepmd")
557578
# translate one atom to ensure a small number of relaxation steps are taken
@@ -563,7 +584,7 @@ def test_deepmd_relax_maker(
563584
optimizer_kwargs={"optimizer": "BFGSLineSearch"},
564585
relax_cell=relax_cell,
565586
fix_symmetry=fix_symmetry,
566-
calculator_kwargs={"model": test_dir / "forcefields" / "deepmd_graph.pb"},
587+
calculator_kwargs={"model": get_deepmd_pretrained_model_path},
567588
).make(sr_ti_o3_structure)
568589

569590
# run the flow or job and ensure that it finished running successfully

tests/forcefields/test_md.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def test_ml_ff_md_maker(
4949
al2_au_structure,
5050
test_dir,
5151
clean_dir,
52+
get_deepmd_pretrained_model_path,
5253
):
5354
if ff_name in map(MLFF, ("Forcefield", "MACE")):
5455
return # nothing to test here, MLFF.Forcefield is just a generic placeholder
@@ -103,7 +104,7 @@ def test_ml_ff_md_maker(
103104
}
104105
unit_cell_structure = sr_ti_o3_structure.copy()
105106
elif ff_name == MLFF.DeepMD:
106-
calculator_kwargs = {"model": test_dir / "forcefields" / "deepmd_graph.pb"}
107+
calculator_kwargs = {"model": get_deepmd_pretrained_model_path}
107108
unit_cell_structure = sr_ti_o3_structure.copy()
108109

109110
structure = unit_cell_structure.to_conventional() * (2, 2, 2)

tests/forcefields/test_phonon.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def test_supercell_orthorhombic(clean_dir, si_structure: Structure):
5555

5656

5757
def test_phonon_maker_initialization_with_all_mlff(
58-
si_structure: Structure, test_dir: Path
58+
si_structure: Structure, test_dir: Path, get_deepmd_pretrained_model_path: Path
5959
):
6060
"""Test PhononMaker can be initialized with all MLFF static and relax makers."""
6161

@@ -74,7 +74,7 @@ def test_phonon_maker_initialization_with_all_mlff(
7474
calc_kwargs = {
7575
MLFF.Nequip: {"model_path": f"{chk_pt_dir}/nequip/nequip_ff_sr_ti_o3.pth"},
7676
MLFF.NEP: {"model_filename": f"{test_dir}/forcefields/nep/nep.txt"},
77-
MLFF.DeepMD: {"model": test_dir / "forcefields" / "deepmd_graph.pb"},
77+
MLFF.DeepMD: {"model": get_deepmd_pretrained_model_path},
7878
}.get(mlff, {})
7979
static_maker = ForceFieldStaticMaker(
8080
name=f"{mlff} static",
86 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)