Skip to content

Commit e6231ad

Browse files
Pin MACE calculator version, add missing metadata to AseStructureTaskDoc (#1119)
* ensure consistent version of MACE model * add missing metadata in ForceFieldTaskDoc, MACE-MPA-0, test for it * precommit * add mace-mp-0b3 * expand from_force_field_name classmethods * precommit * add MACE_MP_0 MLFF option to supersede MACE; throw warning whenever MACE is specified without specific version * precommit
1 parent a82c1b4 commit e6231ad

File tree

13 files changed

+311
-66
lines changed

13 files changed

+311
-66
lines changed

src/atomate2/ase/schemas.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -247,10 +247,12 @@ def from_ase_task_doc(
247247
Additional keyword args passed to :obj:`.AseStructureTaskDoc()`.
248248
"""
249249
task_document_kwargs.update(
250-
{k: getattr(ase_task_doc, k) for k in _task_doc_translation_keys}
250+
{k: getattr(ase_task_doc, k) for k in _task_doc_translation_keys},
251+
structure=ase_task_doc.mol_or_struct,
252+
)
253+
return cls.from_structure(
254+
meta_structure=ase_task_doc.mol_or_struct, **task_document_kwargs
251255
)
252-
task_document_kwargs["structure"] = ase_task_doc.mol_or_struct
253-
return cls(**task_document_kwargs)
254256

255257

256258
class AseMoleculeTaskDoc(MoleculeMetadata):

src/atomate2/forcefields/__init__.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
import warnings
56
from enum import Enum
67
from typing import TYPE_CHECKING
78

@@ -12,7 +13,10 @@
1213
class MLFF(Enum): # TODO inherit from StrEnum when 3.11+
1314
"""Names of ML force fields."""
1415

15-
MACE = "MACE"
16+
MACE = "MACE" # This is MACE-MP-0 (medium), deprecated
17+
MACE_MP_0 = "MACE-MP-0"
18+
MACE_MPA_0 = "MACE-MPA-0"
19+
MACE_MP_0B3 = "MACE-MP-0b3"
1620
GAP = "GAP"
1721
M3GNet = "M3GNet"
1822
CHGNet = "CHGNet"
@@ -27,7 +31,7 @@ def _missing_(cls, value: Any) -> Any:
2731
if isinstance(value, str):
2832
value = value.split("MLFF.")[-1]
2933
for member in cls:
30-
if member.value == value:
34+
if member.name == value:
3135
return member
3236
return None
3337

@@ -45,7 +49,21 @@ def _get_formatted_ff_name(force_field_name: str | MLFF) -> str:
4549
-------
4650
str : the name of the forcefield from MLFF
4751
"""
48-
if isinstance(force_field_name, str) and force_field_name in MLFF.__members__:
52+
if isinstance(force_field_name, str):
4953
# ensure `force_field_name` uses enum format
50-
force_field_name = MLFF(force_field_name)
51-
return str(force_field_name)
54+
if force_field_name in MLFF.__members__:
55+
force_field_name = MLFF[force_field_name]
56+
elif force_field_name in [v.value for v in MLFF]:
57+
force_field_name = MLFF(force_field_name)
58+
force_field_name = str(force_field_name)
59+
if force_field_name in {"MLFF.MACE", "MACE"}:
60+
warnings.warn(
61+
"Because the default MP-trained MACE model is constantly evolving, "
62+
"we no longer recommend using `MACE` or `MLFF.MACE` to specify "
63+
"a MACE model. For reproducibility purposes, specifying `MACE` "
64+
"will still default to MACE-MP-0 (medium), which is identical to "
65+
"specifying `MLFF.MACE_MP_0`.",
66+
category=UserWarning,
67+
stacklevel=2,
68+
)
69+
return force_field_name

src/atomate2/forcefields/flows/elastic.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,7 @@ def from_force_field_name(
127127
**(mlff_kwargs or {}),
128128
"force_field_name": _get_formatted_ff_name(force_field_name),
129129
}
130-
return cls(
131-
name=f"{str(force_field_name).split('MLFF.')[-1]} elastic",
130+
kwargs.update(
132131
bulk_relax_maker=ForceFieldRelaxMaker(
133132
relax_cell=True,
134133
**default_kwargs,
@@ -137,5 +136,8 @@ def from_force_field_name(
137136
relax_cell=False,
138137
**default_kwargs,
139138
),
139+
)
140+
return cls(
141+
name=f"{str(force_field_name).split('MLFF.')[-1]} elastic",
140142
**kwargs,
141143
)

src/atomate2/forcefields/flows/eos.py

+21-14
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,15 @@
88
from monty.dev import deprecated
99

1010
from atomate2.common.flows.eos import CommonEosMaker
11-
from atomate2.forcefields import MLFF, _get_formatted_ff_name
11+
from atomate2.forcefields import _get_formatted_ff_name
1212
from atomate2.forcefields.jobs import ForceFieldRelaxMaker
1313

1414
if TYPE_CHECKING:
1515
from jobflow import Maker
1616
from typing_extensions import Self
1717

18+
from atomate2.forcefields import MLFF
19+
1820

1921
@dataclass
2022
class ForceFieldEosMaker(CommonEosMaker):
@@ -74,25 +76,28 @@ def from_force_field_name(
7476
relax_initial_structure: bool = True
7577
Whether to relax the initial structure before performing an EOS fit.
7678
**kwargs
77-
Additional kwargs to pass to ElasticMaker
79+
Additional kwargs to pass to ForceFieldEosMaker
7880
7981
8082
Returns
8183
-------
8284
ForceFieldEosMaker
8385
"""
8486
force_field_name = _get_formatted_ff_name(force_field_name)
85-
return cls(
86-
name=f"{force_field_name.split('MLFF.')[-1]} EOS Maker",
87-
initial_relax_maker=(
88-
ForceFieldRelaxMaker(force_field_name=force_field_name)
89-
if relax_initial_structure
90-
else None
91-
),
87+
if relax_initial_structure:
88+
kwargs.update(
89+
initial_relax_maker=ForceFieldRelaxMaker(
90+
force_field_name=force_field_name
91+
)
92+
)
93+
kwargs.update(
9294
eos_relax_maker=ForceFieldRelaxMaker(
9395
force_field_name=force_field_name, relax_cell=False
9496
),
9597
static_maker=None,
98+
)
99+
return cls(
100+
name=f"{force_field_name.split('MLFF.')[-1]} EOS Maker",
96101
**kwargs,
97102
)
98103

@@ -202,12 +207,14 @@ class M3GNetEosMaker(CommonEosMaker):
202207
@deprecated(
203208
replacement=ForceFieldEosMaker,
204209
deadline=(2025, 1, 1),
205-
message='Use ForceFieldEosMaker.from_force_field_name(force_field_name = "MACE")',
210+
message=(
211+
'Use ForceFieldEosMaker.from_force_field_name(force_field_name = "MACE-MP-0")'
212+
),
206213
)
207214
@dataclass
208215
class MACEEosMaker(CommonEosMaker):
209216
"""
210-
Generate equation of state data using the MACE ML forcefield.
217+
Generate equation of state data using the MACE-MP-0 ML forcefield.
211218
212219
First relax a structure using relax_maker.
213220
Then perform a series of deformations on the relaxed structure, and
@@ -238,13 +245,13 @@ class MACEEosMaker(CommonEosMaker):
238245
TODO: remove this when clash is fixed
239246
"""
240247

241-
name: str = "MACE EOS Maker"
248+
name: str = "MACE-MP-0 EOS Maker"
242249
initial_relax_maker: Maker = field(
243-
default_factory=lambda: ForceFieldRelaxMaker(force_field_name="MACE")
250+
default_factory=lambda: ForceFieldRelaxMaker(force_field_name="MACE-MP-0")
244251
)
245252
eos_relax_maker: Maker = field(
246253
default_factory=lambda: ForceFieldRelaxMaker(
247-
force_field_name="MACE", relax_cell=False
254+
force_field_name="MACE-MP-0", relax_cell=False
248255
)
249256
)
250257
static_maker: Maker = None

src/atomate2/forcefields/flows/phonons.py

+50-2
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,18 @@
33
from __future__ import annotations
44

55
from dataclasses import dataclass, field
6-
from typing import Literal
6+
from typing import TYPE_CHECKING, Literal
77

88
from atomate2 import SETTINGS
99
from atomate2.common.flows.phonons import BasePhononMaker
10+
from atomate2.forcefields import _get_formatted_ff_name
1011
from atomate2.forcefields.jobs import ForceFieldRelaxMaker, ForceFieldStaticMaker
1112

13+
if TYPE_CHECKING:
14+
from typing_extensions import Self
15+
16+
from atomate2.forcefields import MLFF
17+
1218

1319
@dataclass
1420
class PhononMaker(BasePhononMaker):
@@ -119,7 +125,7 @@ class PhononMaker(BasePhononMaker):
119125
use_symmetrized_structure: Literal["primitive", "conventional"] | None = None
120126
bulk_relax_maker: ForceFieldRelaxMaker | None = field(
121127
default_factory=lambda: ForceFieldRelaxMaker(
122-
force_field_name="CHGNet", relax_kwargs={"fmax": 0.00001}
128+
force_field_name="CHGNet", relax_kwargs={"fmax": 1e-5}
123129
)
124130
)
125131
static_energy_maker: ForceFieldStaticMaker | None = field(
@@ -146,3 +152,45 @@ def prev_calc_dir_argname(self) -> None:
146152
calculations are performed for each ordering (relax -> static)
147153
"""
148154
return
155+
156+
@classmethod
157+
def from_force_field_name(
158+
cls,
159+
force_field_name: str | MLFF,
160+
relax_initial_structure: bool = True,
161+
**kwargs,
162+
) -> Self:
163+
"""
164+
Create a phonon flow from a forcefield name.
165+
166+
Parameters
167+
----------
168+
force_field_name : str or .MLFF
169+
The name of the force field.
170+
relax_initial_structure: bool = True
171+
Whether to relax the initial structure before performing an EOS fit.
172+
**kwargs
173+
Additional kwargs to pass to PhononMaker
174+
175+
176+
Returns
177+
-------
178+
PhononMaker
179+
"""
180+
force_field_name = _get_formatted_ff_name(force_field_name)
181+
if relax_initial_structure:
182+
kwargs.update(
183+
bulk_relax_maker=ForceFieldRelaxMaker(
184+
force_field_name=force_field_name, relax_kwargs={"fmax": 1e-5}
185+
)
186+
)
187+
kwargs.update(
188+
static_energy_maker=ForceFieldStaticMaker(
189+
force_field_name=force_field_name
190+
),
191+
phonon_displacement_maker=ForceFieldStaticMaker(
192+
force_field_name=force_field_name
193+
),
194+
born_maker=None,
195+
)
196+
return cls(name=f"{force_field_name.split('MLFF.')[-1]} Phonon Maker", **kwargs)

0 commit comments

Comments
 (0)