Skip to content

Commit 796ae22

Browse files
precommit
1 parent b6ee0bb commit 796ae22

File tree

6 files changed

+28
-15
lines changed

6 files changed

+28
-15
lines changed

src/atomate2/forcefields/__init__.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22

33
from __future__ import annotations
44

5+
import warnings
56
from enum import Enum
67
from typing import TYPE_CHECKING
7-
import warnings
88

99
if TYPE_CHECKING:
1010
from typing import Any
@@ -56,13 +56,14 @@ def _get_formatted_ff_name(force_field_name: str | MLFF) -> str:
5656
elif force_field_name in [v.value for v in MLFF]:
5757
force_field_name = MLFF(force_field_name)
5858
force_field_name = str(force_field_name)
59-
if force_field_name in {"MLFF.MACE","MACE"}:
59+
if force_field_name in {"MLFF.MACE", "MACE"}:
6060
warnings.warn(
6161
"Because the default MP-trained MACE model is constantly evolving, "
6262
"we no longer recommend using `MACE` or `MLFF.MACE` to specify "
6363
"a MACE model. For reproducibility purposes, specifying `MACE` "
6464
"will still default to MACE-MP-0 (medium), which is identical to "
6565
"specifying `MLFF.MACE_MP_0`.",
6666
category=UserWarning,
67+
stacklevel=2,
6768
)
68-
return force_field_name
69+
return force_field_name

src/atomate2/forcefields/flows/eos.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,9 @@ class M3GNetEosMaker(CommonEosMaker):
207207
@deprecated(
208208
replacement=ForceFieldEosMaker,
209209
deadline=(2025, 1, 1),
210-
message='Use ForceFieldEosMaker.from_force_field_name(force_field_name = "MACE-MP-0")',
210+
message=(
211+
'Use ForceFieldEosMaker.from_force_field_name(force_field_name = "MACE-MP-0")'
212+
),
211213
)
212214
@dataclass
213215
class MACEEosMaker(CommonEosMaker):

src/atomate2/forcefields/jobs.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,10 @@ class NequipStaticMaker(ForceFieldStaticMaker):
540540
@deprecated(
541541
replacement=ForceFieldRelaxMaker,
542542
deadline=(2025, 1, 1),
543-
message="To use MACE-MP-0, set `force_field_name = 'MACE-MP-0'` in ForceFieldRelaxMaker.",
543+
message=(
544+
"To use MACE-MP-0, set `force_field_name = 'MACE-MP-0'` "
545+
"in ForceFieldRelaxMaker."
546+
),
544547
)
545548
@dataclass
546549
class MACERelaxMaker(ForceFieldRelaxMaker):
@@ -590,7 +593,10 @@ class MACERelaxMaker(ForceFieldRelaxMaker):
590593
@deprecated(
591594
replacement=ForceFieldStaticMaker,
592595
deadline=(2025, 1, 1),
593-
message="To use MACE-MP-0, set `force_field_name = 'MACE_MP_0'` in ForceFieldStaticMaker.",
596+
message=(
597+
"To use MACE-MP-0, set `force_field_name = 'MACE_MP_0'` "
598+
"in ForceFieldStaticMaker."
599+
),
594600
)
595601
@dataclass
596602
class MACEStaticMaker(ForceFieldStaticMaker):

src/atomate2/forcefields/md.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,9 @@ class NEPMDMaker(ForceFieldMDMaker):
199199
@deprecated(
200200
replacement=ForceFieldMDMaker,
201201
deadline=(2025, 1, 1),
202-
message="To use MACE-MP-0, set `force_field_name = 'MACE_MP_0'` in ForceFieldMDMaker.",
202+
message=(
203+
"To use MACE-MP-0, set `force_field_name = 'MACE_MP_0'` in ForceFieldMDMaker."
204+
),
203205
)
204206
@dataclass
205207
class MACEMDMaker(ForceFieldMDMaker):

src/atomate2/forcefields/utils.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ def ase_calculator(calculator_meta: str | dict, **kwargs: Any) -> Calculator | N
5858
potential = matgl.load_model(path)
5959
calculator = PESCalculator(potential, **kwargs)
6060

61-
elif calculator_name in map(MLFF, ("MACE","MACE-MP-0", "MACE-MPA-0", "MACE-MP-0b3")):
61+
elif calculator_name in map(
62+
MLFF, ("MACE", "MACE-MP-0", "MACE-MPA-0", "MACE-MP-0b3")
63+
):
6264
from mace.calculators import MACECalculator, mace_mp
6365

6466
model = kwargs.get("model")

tests/forcefields/test_md.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Tests for forcefield MD flows."""
22

3-
from contextlib import nullcontext
43
import sys
4+
from contextlib import nullcontext
55
from pathlib import Path
66

77
import numpy as np
@@ -43,22 +43,22 @@ def test_maker_initialization():
4343
for mlff in MLFF.__members__:
4444
context_mgr = nullcontext()
4545
if mlff == "MACE":
46-
context_mgr = pytest.warns(UserWarning, match = "default MP-trained MACE")
47-
46+
context_mgr = pytest.warns(UserWarning, match="default MP-trained MACE")
47+
4848
with context_mgr:
4949
assert ForceFieldMDMaker(force_field_name=MLFF(mlff)) == ForceFieldMDMaker(
5050
force_field_name=mlff
5151
)
52-
assert ForceFieldMDMaker(force_field_name=str(MLFF(mlff))) == ForceFieldMDMaker(
53-
force_field_name=mlff
54-
)
52+
assert ForceFieldMDMaker(
53+
force_field_name=str(MLFF(mlff))
54+
) == ForceFieldMDMaker(force_field_name=mlff)
5555

5656

5757
@pytest.mark.parametrize("ff_name", MLFF)
5858
def test_ml_ff_md_maker(
5959
ff_name, si_structure, sr_ti_o3_structure, al2_au_structure, test_dir, clean_dir
6060
):
61-
if ff_name in map(MLFF, ("Forcefield","MACE")):
61+
if ff_name in map(MLFF, ("Forcefield", "MACE")):
6262
return # nothing to test here, MLFF.Forcefield is just a generic placeholder
6363
if ff_name == MLFF.GAP and sys.version_info >= (3, 12):
6464
pytest.skip(

0 commit comments

Comments
 (0)