Skip to content

Commit 97014ed

Browse files
fix trajectory observer and file naming
1 parent 34f46e7 commit 97014ed

File tree

2 files changed

+26
-13
lines changed

2 files changed

+26
-13
lines changed

src/atomate2/ase/utils.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import sys
99
import time
1010
from copy import deepcopy
11+
from pathlib import Path
1112
from typing import TYPE_CHECKING
1213

1314
import numpy as np
@@ -456,7 +457,7 @@ def run_neb(
456457
images: list[Atoms | Structure | Molecule],
457458
fmax: float = 0.1,
458459
steps: int = 500,
459-
traj_file: str = None,
460+
traj_file: str | Path | list[str | Path] = None,
460461
traj_file_fmt: Literal["pmg", "ase", "xdatcar"] = "ase",
461462
interval: int = 1,
462463
verbose: bool = False,
@@ -474,8 +475,16 @@ def run_neb(
474475
Total force tolerance for relaxation convergence.
475476
steps : int
476477
Max number of steps for relaxation.
477-
traj_file : str
478-
The trajectory file for saving.
478+
traj_file : str, Path, or a list of str / Path
479+
The trajectory file for saving. If a single str or Path,
480+
this specifies the file name prefix. For example,
481+
`traj_file = "traj_mp-149.json.gz"`
482+
will yield individual trajectory file names:
483+
traj_mp-149-image-1.json.gz
484+
traj_mp-149-image-2.json.gz
485+
...
486+
Alternately, if this is a list of str / Path, this specifies the
487+
file name for each image.
479488
interval : int
480489
The step interval for saving the trajectories.
481490
verbose : bool
@@ -506,21 +515,26 @@ def run_neb(
506515
observers = [TrajectoryObserver(image) for image in images]
507516
optimizer = self.opt_class(neb_calc, **kwargs)
508517
for idx in range(num_images):
509-
optimizer.attach(observers[idx], interval=interval, atoms=images[idx])
518+
optimizer.attach(observers[idx], interval=interval)
510519
t_i = time.perf_counter()
511520
optimizer.run(fmax=fmax, steps=steps)
512521
t_f = time.perf_counter()
513522
[observers[idx]() for idx in range(num_images)]
514523

515524
if traj_file is not None:
516-
for idx in range(num_images):
517-
traj_file_split = traj_file.split(".")
518-
traj_file_prefix = ".".join(traj_file_split[:-1])
519-
traj_file_ext = traj_file[-1]
520-
observers[idx].save(
521-
f"{traj_file_prefix}-image-{idx + 1}.{traj_file_ext}",
522-
fmt=traj_file_fmt,
523-
)
525+
if isinstance(traj_file, str | Path):
526+
traj_file = Path(traj_file)
527+
traj_file_suffix = "".join(traj_file.suffixes)
528+
traj_file_prefix = str(traj_file).split(traj_file_suffix)[0]
529+
traj_files = [
530+
f"{traj_file_prefix}-image-{idx + 1}{traj_file_suffix}"
531+
for idx in range(num_images)
532+
]
533+
elif isinstance(traj_file, list | tuple):
534+
traj_files = [str(f) for f in traj_file]
535+
536+
for idx, f in enumerate(traj_files):
537+
observers[idx].save(f, fmt=traj_file_fmt)
524538

525539
images = [
526540
self.ase_adaptor.get_structure(image, cls=Molecule if is_mol else Structure)

src/atomate2/forcefields/neb.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ class ForceFieldNebMaker(AseNebMaker):
3232

3333
def __post_init__(self) -> None:
3434
"""Ensure that force_field_name is correctly assigned."""
35-
super().__post_init__()
3635
self.force_field_name = _get_formatted_ff_name(self.force_field_name)
3736

3837
# Pad calculator_kwargs with default values, but permit user to override them

0 commit comments

Comments
 (0)