88import sys
99import time
1010from copy import deepcopy
11+ from pathlib import Path
1112from typing import TYPE_CHECKING
1213
1314import numpy as np
@@ -456,7 +457,7 @@ def run_neb(
456457 images : list [Atoms | Structure | Molecule ],
457458 fmax : float = 0.1 ,
458459 steps : int = 500 ,
459- traj_file : str = None ,
460+ traj_file : str | Path | list [ str | Path ] = None ,
460461 traj_file_fmt : Literal ["pmg" , "ase" , "xdatcar" ] = "ase" ,
461462 interval : int = 1 ,
462463 verbose : bool = False ,
@@ -474,8 +475,16 @@ def run_neb(
474475 Total force tolerance for relaxation convergence.
475476 steps : int
476477 Max number of steps for relaxation.
477- traj_file : str
478- The trajectory file for saving.
478+ traj_file : str, Path, or a list of str / Path
479+ The trajectory file for saving. If a single str or Path,
480+ this specifies the file name prefix. For example,
481+ `traj_file = "traj_mp-149.json.gz"`
482+ will yield individual trajectory file names:
483+ traj_mp-149-image-1.json.gz
484+ traj_mp-149-image-2.json.gz
485+ ...
486+ Alternately, if this is a list of str / Path, this specifies the
487+ file name for each image.
479488 interval : int
480489 The step interval for saving the trajectories.
481490 verbose : bool
@@ -506,21 +515,26 @@ def run_neb(
506515 observers = [TrajectoryObserver (image ) for image in images ]
507516 optimizer = self .opt_class (neb_calc , ** kwargs )
508517 for idx in range (num_images ):
509- optimizer .attach (observers [idx ], interval = interval , atoms = images [ idx ] )
518+ optimizer .attach (observers [idx ], interval = interval )
510519 t_i = time .perf_counter ()
511520 optimizer .run (fmax = fmax , steps = steps )
512521 t_f = time .perf_counter ()
513522 [observers [idx ]() for idx in range (num_images )]
514523
515524 if traj_file is not None :
516- for idx in range (num_images ):
517- traj_file_split = traj_file .split ("." )
518- traj_file_prefix = "." .join (traj_file_split [:- 1 ])
519- traj_file_ext = traj_file [- 1 ]
520- observers [idx ].save (
521- f"{ traj_file_prefix } -image-{ idx + 1 } .{ traj_file_ext } " ,
522- fmt = traj_file_fmt ,
523- )
525+ if isinstance (traj_file , str | Path ):
526+ traj_file = Path (traj_file )
527+ traj_file_suffix = "" .join (traj_file .suffixes )
528+ traj_file_prefix = str (traj_file ).split (traj_file_suffix )[0 ]
529+ traj_files = [
530+ f"{ traj_file_prefix } -image-{ idx + 1 } { traj_file_suffix } "
531+ for idx in range (num_images )
532+ ]
533+ elif isinstance (traj_file , list | tuple ):
534+ traj_files = [str (f ) for f in traj_file ]
535+
536+ for idx , f in enumerate (traj_files ):
537+ observers [idx ].save (f , fmt = traj_file_fmt )
524538
525539 images = [
526540 self .ase_adaptor .get_structure (image , cls = Molecule if is_mol else Structure )
0 commit comments