Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pin MACE calculator version, add missing metadata to AseStructureTaskDoc #1119

Merged
merged 8 commits into from
Feb 11, 2025
8 changes: 5 additions & 3 deletions src/atomate2/ase/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,10 +247,12 @@ def from_ase_task_doc(
Additional keyword args passed to :obj:`.AseStructureTaskDoc()`.
"""
task_document_kwargs.update(
{k: getattr(ase_task_doc, k) for k in _task_doc_translation_keys}
{k: getattr(ase_task_doc, k) for k in _task_doc_translation_keys},
structure=ase_task_doc.mol_or_struct,
)
return cls.from_structure(
meta_structure=ase_task_doc.mol_or_struct, **task_document_kwargs
)
task_document_kwargs["structure"] = ase_task_doc.mol_or_struct
return cls(**task_document_kwargs)


class AseMoleculeTaskDoc(MoleculeMetadata):
Expand Down
28 changes: 23 additions & 5 deletions src/atomate2/forcefields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import warnings
from enum import Enum
from typing import TYPE_CHECKING

Expand All @@ -12,7 +13,10 @@
class MLFF(Enum): # TODO inherit from StrEnum when 3.11+
"""Names of ML force fields."""

MACE = "MACE"
MACE = "MACE" # This is MACE-MP-0 (medium), deprecated
MACE_MP_0 = "MACE-MP-0"
MACE_MPA_0 = "MACE-MPA-0"
MACE_MP_0B3 = "MACE-MP-0b3"
GAP = "GAP"
M3GNet = "M3GNet"
CHGNet = "CHGNet"
Expand All @@ -27,7 +31,7 @@ def _missing_(cls, value: Any) -> Any:
if isinstance(value, str):
value = value.split("MLFF.")[-1]
for member in cls:
if member.value == value:
if member.name == value:
return member
return None

Expand All @@ -45,7 +49,21 @@ def _get_formatted_ff_name(force_field_name: str | MLFF) -> str:
-------
str : the name of the forcefield from MLFF
"""
if isinstance(force_field_name, str) and force_field_name in MLFF.__members__:
if isinstance(force_field_name, str):
# ensure `force_field_name` uses enum format
force_field_name = MLFF(force_field_name)
return str(force_field_name)
if force_field_name in MLFF.__members__:
force_field_name = MLFF[force_field_name]
elif force_field_name in [v.value for v in MLFF]:
force_field_name = MLFF(force_field_name)
force_field_name = str(force_field_name)
if force_field_name in {"MLFF.MACE", "MACE"}:
warnings.warn(
"Because the default MP-trained MACE model is constantly evolving, "
"we no longer recommend using `MACE` or `MLFF.MACE` to specify "
"a MACE model. For reproducibility purposes, specifying `MACE` "
"will still default to MACE-MP-0 (medium), which is identical to "
"specifying `MLFF.MACE_MP_0`.",
category=UserWarning,
stacklevel=2,
)
return force_field_name
6 changes: 4 additions & 2 deletions src/atomate2/forcefields/flows/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,7 @@ def from_force_field_name(
**(mlff_kwargs or {}),
"force_field_name": _get_formatted_ff_name(force_field_name),
}
return cls(
name=f"{str(force_field_name).split('MLFF.')[-1]} elastic",
kwargs.update(
bulk_relax_maker=ForceFieldRelaxMaker(
relax_cell=True,
**default_kwargs,
Expand All @@ -137,5 +136,8 @@ def from_force_field_name(
relax_cell=False,
**default_kwargs,
),
)
return cls(
name=f"{str(force_field_name).split('MLFF.')[-1]} elastic",
**kwargs,
)
35 changes: 21 additions & 14 deletions src/atomate2/forcefields/flows/eos.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
from monty.dev import deprecated

from atomate2.common.flows.eos import CommonEosMaker
from atomate2.forcefields import MLFF, _get_formatted_ff_name
from atomate2.forcefields import _get_formatted_ff_name
from atomate2.forcefields.jobs import ForceFieldRelaxMaker

if TYPE_CHECKING:
from jobflow import Maker
from typing_extensions import Self

from atomate2.forcefields import MLFF


@dataclass
class ForceFieldEosMaker(CommonEosMaker):
Expand Down Expand Up @@ -74,25 +76,28 @@ def from_force_field_name(
relax_initial_structure: bool = True
Whether to relax the initial structure before performing an EOS fit.
**kwargs
Additional kwargs to pass to ElasticMaker
Additional kwargs to pass to ForceFieldEosMaker


Returns
-------
ForceFieldEosMaker
"""
force_field_name = _get_formatted_ff_name(force_field_name)
return cls(
name=f"{force_field_name.split('MLFF.')[-1]} EOS Maker",
initial_relax_maker=(
ForceFieldRelaxMaker(force_field_name=force_field_name)
if relax_initial_structure
else None
),
if relax_initial_structure:
kwargs.update(
initial_relax_maker=ForceFieldRelaxMaker(
force_field_name=force_field_name
)
)
kwargs.update(
eos_relax_maker=ForceFieldRelaxMaker(
force_field_name=force_field_name, relax_cell=False
),
static_maker=None,
)
return cls(
name=f"{force_field_name.split('MLFF.')[-1]} EOS Maker",
**kwargs,
)

Expand Down Expand Up @@ -202,12 +207,14 @@ class M3GNetEosMaker(CommonEosMaker):
@deprecated(
replacement=ForceFieldEosMaker,
deadline=(2025, 1, 1),
message='Use ForceFieldEosMaker.from_force_field_name(force_field_name = "MACE")',
message=(
'Use ForceFieldEosMaker.from_force_field_name(force_field_name = "MACE-MP-0")'
),
)
@dataclass
class MACEEosMaker(CommonEosMaker):
"""
Generate equation of state data using the MACE ML forcefield.
Generate equation of state data using the MACE-MP-0 ML forcefield.

First relax a structure using relax_maker.
Then perform a series of deformations on the relaxed structure, and
Expand Down Expand Up @@ -238,13 +245,13 @@ class MACEEosMaker(CommonEosMaker):
TODO: remove this when clash is fixed
"""

name: str = "MACE EOS Maker"
name: str = "MACE-MP-0 EOS Maker"
initial_relax_maker: Maker = field(
default_factory=lambda: ForceFieldRelaxMaker(force_field_name="MACE")
default_factory=lambda: ForceFieldRelaxMaker(force_field_name="MACE-MP-0")
)
eos_relax_maker: Maker = field(
default_factory=lambda: ForceFieldRelaxMaker(
force_field_name="MACE", relax_cell=False
force_field_name="MACE-MP-0", relax_cell=False
)
)
static_maker: Maker = None
52 changes: 50 additions & 2 deletions src/atomate2/forcefields/flows/phonons.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,18 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Literal
from typing import TYPE_CHECKING, Literal

from atomate2 import SETTINGS
from atomate2.common.flows.phonons import BasePhononMaker
from atomate2.forcefields import _get_formatted_ff_name
from atomate2.forcefields.jobs import ForceFieldRelaxMaker, ForceFieldStaticMaker

if TYPE_CHECKING:
from typing_extensions import Self

from atomate2.forcefields import MLFF


@dataclass
class PhononMaker(BasePhononMaker):
Expand Down Expand Up @@ -119,7 +125,7 @@ class PhononMaker(BasePhononMaker):
use_symmetrized_structure: Literal["primitive", "conventional"] | None = None
bulk_relax_maker: ForceFieldRelaxMaker | None = field(
default_factory=lambda: ForceFieldRelaxMaker(
force_field_name="CHGNet", relax_kwargs={"fmax": 0.00001}
force_field_name="CHGNet", relax_kwargs={"fmax": 1e-5}
)
)
static_energy_maker: ForceFieldStaticMaker | None = field(
Expand All @@ -146,3 +152,45 @@ def prev_calc_dir_argname(self) -> None:
calculations are performed for each ordering (relax -> static)
"""
return

@classmethod
def from_force_field_name(
cls,
force_field_name: str | MLFF,
relax_initial_structure: bool = True,
**kwargs,
) -> Self:
"""
Create a phonon flow from a forcefield name.

Parameters
----------
force_field_name : str or .MLFF
The name of the force field.
relax_initial_structure: bool = True
Whether to relax the initial structure before performing an EOS fit.
**kwargs
Additional kwargs to pass to PhononMaker


Returns
-------
PhononMaker
"""
force_field_name = _get_formatted_ff_name(force_field_name)
if relax_initial_structure:
kwargs.update(
bulk_relax_maker=ForceFieldRelaxMaker(
force_field_name=force_field_name, relax_kwargs={"fmax": 1e-5}
)
)
kwargs.update(
static_energy_maker=ForceFieldStaticMaker(
force_field_name=force_field_name
),
phonon_displacement_maker=ForceFieldStaticMaker(
force_field_name=force_field_name
),
born_maker=None,
)
return cls(name=f"{force_field_name.split('MLFF.')[-1]} Phonon Maker", **kwargs)
Loading
Loading