Skip to content

Commit 5640ecf

Browse files
Alternative ASE/MLFF batch mode (#1456)
* cache ase calculator * batch mode extension in phonons + tests adapted from #1196 * remove abstractmethod to ensure backwards compat * add socket mode to EOS maker * eos batch fix
1 parent c2992c8 commit 5640ecf

13 files changed

Lines changed: 326 additions & 156 deletions

File tree

src/atomate2/ase/jobs.py

Lines changed: 77 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import logging
66
import time
7-
from abc import ABC, abstractmethod
7+
from abc import ABC
88
from dataclasses import dataclass, field
99
from typing import TYPE_CHECKING
1010

@@ -54,8 +54,7 @@ class AseMaker(Maker, ABC):
5454
class EMTStaticMaker(AseMaker):
5555
name: str = "EMT static maker"
5656
57-
@property
58-
def calculator(self):
57+
def _get_calculator(self):
5958
return EMT()
6059
```
6160
@@ -95,27 +94,44 @@ def calculator(self):
9594
store_trajectory: StoreTrajectoryOption = StoreTrajectoryOption.NO
9695
tags: list[str] | None = None
9796

97+
def __post_init__(self) -> None:
98+
"""Enable caching of the ASE calculator via private attribute."""
99+
self._calculator: Calculator | None = None
100+
98101
@job(data=_ASE_DATA_OBJECTS)
99102
def make(
100103
self,
101-
mol_or_struct: Molecule | Structure,
104+
mol_or_struct: Molecule | Structure | list[Molecule | Structure],
102105
prev_dir: str | Path | None = None,
103-
) -> AseStructureTaskDoc | AseMoleculeTaskDoc:
106+
) -> (
107+
AseStructureTaskDoc
108+
| AseMoleculeTaskDoc
109+
| list[AseStructureTaskDoc | AseMoleculeTaskDoc]
110+
):
104111
"""
105112
Run ASE as job, can be re-implemented in subclasses.
106113
107114
Parameters
108115
----------
109-
mol_or_struct: .Molecule or .Structure
110-
pymatgen molecule or structure
116+
mol_or_struct: .Molecule, .Structure, or a list thereof
117+
pymatgen molecule(s) or structure(s)
111118
prev_dir : str or Path or None
112119
A previous calculation directory to copy output files from. Unused, just
113120
added to match the method signature of other makers.
121+
122+
Returns
123+
-------
124+
AseStructureTaskDoc, AseMoleculeTaskDoc, or list thereof.
114125
"""
115-
return AseTaskDoc.to_mol_or_struct_metadata_doc(
116-
getattr(self.calculator, "name", type(self.calculator).__name__),
117-
self.run_ase(mol_or_struct, prev_dir=prev_dir),
118-
)
126+
batch_mode = isinstance(mol_or_struct, list)
127+
results = [
128+
AseTaskDoc.to_mol_or_struct_metadata_doc(
129+
getattr(self.calculator, "name", type(self.calculator).__name__),
130+
self.run_ase(atoms, prev_dir=prev_dir),
131+
)
132+
for atoms in (mol_or_struct if batch_mode else [mol_or_struct])
133+
]
134+
return results if batch_mode else results[0]
119135

120136
def run_ase(
121137
self,
@@ -148,11 +164,25 @@ def run_ase(
148164
elapsed_time=t_f - t_i,
149165
)
150166

167+
def _get_calculator(self) -> Calculator:
168+
"""Load ASE calculator, to be implemented by the user.
169+
170+
NB: To avoid breaking behavior, this method by default
171+
does nothing and *should not* be an `abstractmethod`.
172+
173+
Previously, users would define the `calculator` attr
174+
directly. That is still possible but will not benefit
175+
from caching the calculator.
176+
"""
177+
151178
@property
152-
@abstractmethod
153179
def calculator(self) -> Calculator:
154-
"""ASE calculator, method to be implemented in subclasses."""
155-
raise NotImplementedError
180+
"""Retrieve cached ASE calculator."""
181+
if getattr(self, "_calculator", None) is None:
182+
self._calculator = self._get_calculator()
183+
if self._calculator is None:
184+
raise ValueError("ASE calculator not properly initialized.")
185+
return self._calculator
156186

157187

158188
@dataclass
@@ -208,8 +238,7 @@ class AseRelaxMaker(AseMaker):
208238

209239
def __post_init__(self) -> None:
210240
"""Ensure that physical relaxation settings are used."""
211-
if hasattr(super(), "__post_init__"):
212-
super().__post_init__() # type: ignore[misc]
241+
super().__post_init__()
213242
if self.relax_cell and self.relax_shape:
214243
raise ValueError(
215244
"You have set both `relax_cell` (relaxing the cell shape and volume) "
@@ -220,38 +249,48 @@ def __post_init__(self) -> None:
220249
@job(data=_ASE_DATA_OBJECTS)
221250
def make(
222251
self,
223-
mol_or_struct: Molecule | Structure,
252+
mol_or_struct: Molecule | Structure | list[Molecule | Structure],
224253
prev_dir: str | Path | None = None,
225-
) -> AseStructureTaskDoc | AseMoleculeTaskDoc:
254+
) -> (
255+
AseStructureTaskDoc
256+
| AseMoleculeTaskDoc
257+
| list[AseStructureTaskDoc | AseMoleculeTaskDoc]
258+
):
226259
"""
227260
Relax a structure or molecule using ASE as a job.
228261
229262
Parameters
230263
----------
231-
mol_or_struct: .Molecule or .Structure
232-
pymatgen molecule or structure
264+
mol_or_struct: .Molecule or .Structure, or list thereof
265+
pymatgen molecule(s) or structure(s)
233266
prev_dir : str or Path or None
234267
A previous calculation directory to copy output files from. Unused, just
235268
added to match the method signature of other makers.
236269
237270
Returns
238271
-------
239-
AseStructureTaskDoc or AseMoleculeTaskDoc
272+
AseStructureTaskDoc or AseMoleculeTaskDoc, or list thereof
240273
"""
241-
return AseTaskDoc.to_mol_or_struct_metadata_doc(
242-
getattr(self.calculator, "name", type(self.calculator).__name__),
243-
self.run_ase(mol_or_struct, prev_dir=prev_dir),
244-
self.steps,
245-
relax_kwargs=self.relax_kwargs,
246-
optimizer_kwargs=self.optimizer_kwargs,
247-
relax_cell=self.relax_cell,
248-
relax_shape=self.relax_shape,
249-
fix_symmetry=self.fix_symmetry,
250-
symprec=self.symprec if self.fix_symmetry else None,
251-
ionic_step_data=self.ionic_step_data,
252-
store_trajectory=self.store_trajectory,
253-
tags=self.tags,
254-
)
274+
batch_mode = isinstance(mol_or_struct, list)
275+
276+
results = [
277+
AseTaskDoc.to_mol_or_struct_metadata_doc(
278+
getattr(self.calculator, "name", type(self.calculator).__name__),
279+
self.run_ase(atoms, prev_dir=prev_dir),
280+
self.steps,
281+
relax_kwargs=self.relax_kwargs,
282+
optimizer_kwargs=self.optimizer_kwargs,
283+
relax_cell=self.relax_cell,
284+
relax_shape=self.relax_shape,
285+
fix_symmetry=self.fix_symmetry,
286+
symprec=self.symprec if self.fix_symmetry else None,
287+
ionic_step_data=self.ionic_step_data,
288+
store_trajectory=self.store_trajectory,
289+
tags=self.tags,
290+
)
291+
for atoms in (mol_or_struct if batch_mode else [mol_or_struct])
292+
]
293+
return results if batch_mode else results[0]
255294

256295
def run_ase(
257296
self,
@@ -299,8 +338,7 @@ class EmtRelaxMaker(AseRelaxMaker):
299338

300339
name: str = "EMT relaxation"
301340

302-
@property
303-
def calculator(self) -> Calculator:
341+
def _get_calculator(self) -> Calculator:
304342
"""EMT calculator."""
305343
from ase.calculators.emt import EMT
306344

@@ -320,8 +358,7 @@ class LennardJonesRelaxMaker(AseRelaxMaker):
320358

321359
name: str = "Lennard-Jones 6-12 relaxation"
322360

323-
@property
324-
def calculator(self) -> Calculator:
361+
def _get_calculator(self) -> None:
325362
"""Lennard-Jones calculator."""
326363
from ase.calculators.lj import LennardJones
327364

@@ -378,8 +415,7 @@ class GFNxTBRelaxMaker(AseRelaxMaker):
378415
}
379416
)
380417

381-
@property
382-
def calculator(self) -> Calculator:
418+
def _get_calculator(self) -> None:
383419
"""GFN-xTB / TBLite calculator."""
384420
try:
385421
from tblite.ase import TBLite

src/atomate2/ase/md.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import os
99
import sys
1010
import time
11-
from abc import ABC, abstractmethod
11+
from abc import ABC
1212
from collections.abc import Sequence
1313
from dataclasses import dataclass, field
1414
from enum import Enum
@@ -189,6 +189,7 @@ class AseMDMaker(AseMaker, ABC):
189189

190190
def __post_init__(self) -> None:
191191
"""Ensure that ensemble is an enum."""
192+
super().__post_init__()
192193
if isinstance(self.ensemble, str):
193194
self.ensemble = MDEnsemble(self.ensemble.split("MDEnsemble.")[-1])
194195

@@ -444,12 +445,6 @@ def _callback(dyn: MolecularDynamics = md_runner) -> None:
444445
elapsed_time=t_f - t_i,
445446
)
446447

447-
@property
448-
@abstractmethod
449-
def calculator(self) -> Calculator:
450-
"""ASE calculator, to be overwritten by user."""
451-
raise NotImplementedError
452-
453448

454449
@dataclass
455450
class LennardJonesMDMaker(AseMDMaker):
@@ -461,8 +456,7 @@ class LennardJonesMDMaker(AseMDMaker):
461456

462457
name: str = "Lennard-Jones 6-12 MD"
463458

464-
@property
465-
def calculator(self) -> Calculator:
459+
def _get_calculator(self) -> Calculator:
466460
"""Lennard-Jones calculator."""
467461
from ase.calculators.lj import LennardJones
468462

@@ -495,8 +489,7 @@ class GFNxTBMDMaker(AseMDMaker):
495489
}
496490
)
497491

498-
@property
499-
def calculator(self) -> Calculator:
492+
def _get_calculator(self) -> Calculator:
500493
"""GFN-xTB / TBLite calculator."""
501494
try:
502495
from tblite.ase import TBLite

src/atomate2/ase/neb.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,8 +257,7 @@ class EmtNebFromImagesMaker(AseNebFromImagesMaker):
257257

258258
name: str = "EMT NEB from images maker"
259259

260-
@property
261-
def calculator(self) -> Calculator:
260+
def _get_calculator(self) -> Calculator:
262261
"""EMT calculator."""
263262
from ase.calculators.emt import EMT
264263

0 commit comments

Comments
 (0)