Skip to content

Commit fee3e62

Browse files
committed
add _terminate_process() on class Custodian
- Refactor _do_check() to use _terminate_process() instead of passing terminate_func - DRY: centralizes process termination logic in one place - uses multi-node compatible approach from PR #396
1 parent 7fc3800 commit fee3e62

File tree

1 file changed

+43
-42
lines changed

1 file changed

+43
-42
lines changed

src/custodian/custodian.py

Lines changed: 43 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -456,52 +456,47 @@ def _run_job(self, job_n, job) -> None:
456456
job.setup(self.directory)
457457

458458
attempt = 0
459-
p = None # Initialize p to None in case the while loop never executes
459+
process = None # Initialize p to None in case the while loop never executes
460460
while self.total_errors < self.max_errors and self.errors_current_job < self.max_errors_per_job:
461461
attempt += 1
462462
logger.info(
463463
f"Starting job no. {job_n} ({job.name}) attempt no. {attempt}. Total errors and "
464464
f"errors in job thus far = {self.total_errors}, {self.errors_current_job}."
465465
)
466466

467-
p = job.run(directory=self.directory)
467+
process = job.run(directory=self.directory)
468468
# Check for errors using the error handlers and perform
469469
# corrections.
470470
has_error = False
471471
zero_return_code = True
472472

473-
# Choose the terminate function to run. If a terminate_func exists, this
474-
# should take priority, followed by Job.terminate if implemented, and finally
475-
# subprocess.Popen.terminate if neither of the former exist.
476-
terminate = self.terminate_func or job.terminate or p.terminate
477-
478473
# While the job is running, we use the handlers that are
479474
# monitors to monitor the job.
480-
if isinstance(p, subprocess.Popen):
475+
if isinstance(process, subprocess.Popen):
481476
if self.monitors:
482-
n = 0
477+
poll_idx = 0
483478
while True:
484-
n += 1
479+
poll_idx += 1
485480
time.sleep(self.polling_time_step)
486481
# We poll the process p to check if it is still running.
487482
# Note that the process here is not the actual calculation
488483
# but whatever is used to control the execution of the
489484
# calculation executable. For instance; mpirun, srun, and so on.
490-
if p.poll() is not None:
485+
if process.poll() is not None:
491486
break
492-
if n % self.monitor_freq == 0:
487+
if poll_idx % self.monitor_freq == 0:
493488
# At every self.polling_time_step * self.monitor_freq seconds,
494489
# we check the job for errors using handlers that are monitors.
495490
# In order to properly kill a running calculation, we use
496491
# the appropriate implementation of terminate.
497-
has_error = self._do_check(self.monitors, terminate)
492+
has_error = self._do_check(self.monitors, process, job)
498493
else:
499-
p.wait()
500-
if self.terminate_func is not None and self.terminate_func != p.terminate:
494+
process.wait()
495+
if self.terminate_func is not None and self.terminate_func != process.terminate:
501496
self.terminate_func()
502497
time.sleep(self.polling_time_step)
503498

504-
zero_return_code = p.returncode == 0
499+
zero_return_code = process.returncode == 0
505500

506501
logger.info(f"{job.name}.run has completed. Checking remaining handlers")
507502
# Check for errors again, since in some cases non-monitor
@@ -523,7 +518,7 @@ def _run_job(self, job_n, job) -> None:
523518
if not zero_return_code:
524519
if self.terminate_on_nonzero_returncode:
525520
self.run_log[-1]["nonzero_return_code"] = True
526-
msg = f"Job return code is {p.returncode}. Terminating..."
521+
msg = f"Job return code is {process.returncode}. Terminating..."
527522
logger.info(msg)
528523
raise ReturnCodeError(msg, raises=True)
529524
warnings.warn("subprocess returned a non-zero return code. Check outputs carefully...")
@@ -543,25 +538,9 @@ def _run_job(self, job_n, job) -> None:
543538
raise NonRecoverableError(msg, raises=False, handler=corr["handler"])
544539

545540
# Terminate any running process before raising max errors exceptions
546-
if isinstance(p, subprocess.Popen) and p.poll() is None:
547-
logger.warning("Max errors threshold reached. Terminating running process.")
548-
terminate = self.terminate_func or job.terminate or p.terminate
549-
try:
550-
# Call terminate with directory parameter if it's not the default Popen terminate
551-
if terminate != p.terminate:
552-
terminate(directory=self.directory)
553-
else:
554-
terminate()
555-
# Wait briefly for process to terminate
556-
if hasattr(p, "wait"):
557-
try:
558-
p.wait(timeout=10)
559-
except subprocess.TimeoutExpired:
560-
logger.warning("Process did not terminate gracefully, force killing")
561-
p.kill()
562-
p.wait()
563-
except Exception:
564-
logger.exception("Error terminating process")
541+
if process is not None:
542+
logger.warning("Max errors threshold reached.")
543+
self._terminate_process(process, job)
565544

566545
if self.errors_current_job >= self.max_errors_per_job:
567546
self.run_log[-1]["max_errors_per_job"] = True
@@ -675,8 +654,30 @@ def run_interrupted(self):
675654
gzip_dir(self.directory)
676655
return None
677656

678-
def _do_check(self, handlers, terminate_func=None):
679-
"""Checks the specified handlers. Returns True iff errors caught."""
657+
def _terminate_process(self, process, job) -> None:
658+
"""Terminate a running subprocess using the job's terminate method or fallback."""
659+
if not isinstance(process, subprocess.Popen) or process.poll() is not None:
660+
return # Not a process or already finished
661+
662+
logger.warning("Terminating running process.")
663+
terminate = self.terminate_func or job.terminate or process.terminate
664+
665+
try:
666+
if terminate != process.terminate:
667+
terminate(directory=self.directory)
668+
else:
669+
terminate()
670+
try:
671+
process.wait(timeout=10)
672+
except subprocess.TimeoutExpired:
673+
logger.warning("Process did not terminate gracefully, force killing")
674+
process.kill()
675+
process.wait()
676+
except Exception:
677+
logger.exception("Error terminating process")
678+
679+
def _do_check(self, handlers, process=None, job=None):
680+
"""Check handlers and return True if errors were caught."""
680681
corrections = []
681682
for handler in handlers:
682683
try:
@@ -694,11 +695,11 @@ def _do_check(self, handlers, terminate_func=None):
694695
)
695696
logger.warning(f"{msg} Correction not applied.")
696697
continue
697-
if terminate_func is not None and handler.is_terminating:
698+
if process is not None and job is not None and handler.is_terminating:
698699
logger.info("Terminating job")
699-
terminate_func(directory=self.directory)
700-
# make sure we don't terminate twice
701-
terminate_func = None
700+
self._terminate_process(process, job)
701+
# Make sure we don't terminate twice
702+
process = None
702703
dct = handler.correct(directory=self.directory)
703704
logger.error(type(handler).__name__, extra=dct)
704705
dct["handler"] = handler

0 commit comments

Comments
 (0)