Don't pass class variable #3772

Merged 1 commit on Feb 13, 2025
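The single commit replaces the kill_event parameter that start() previously passed to every loop thread with a self._stop_event attribute that each loop reads directly. A minimal runnable sketch of the before/after pattern; the Manager class and method names below are illustrative stand-ins, not Parsl's API:

import threading
import time

class Manager:
    def __init__(self):
        # One shared attribute replaces the kill_event argument that each
        # thread target previously received via args=(self._kill_event,).
        self._stop_event = threading.Event()

    def worker_loop(self):
        # Reads the instance attribute directly; no parameter passing needed.
        while not self._stop_event.is_set():
            time.sleep(0.01)

    def start(self):
        t = threading.Thread(target=self.worker_loop, name="Worker")
        t.start()
        self._stop_event.set()   # any component may request shutdown
        t.join()

Manager().start()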
66 changes: 23 additions & 43 deletions parsl/executors/high_throughput/process_worker_pool.py
@@ -215,7 +215,7 @@ def __init__(self, *,

self._mp_manager = SpawnContext.Manager() # Starts a server process
self._tasks_in_progress = self._mp_manager.dict()
- self._kill_event = threading.Event()
+ self._stop_event = threading.Event() # when set, will begin shutdown process

self.monitoring_queue = self._mp_manager.Queue()
self.pending_task_queue = SpawnContext.Queue()
@@ -298,14 +298,9 @@ def drain_to_incoming(self):
logger.debug("Sent drain")

@wrap_with_logs
- def pull_tasks(self, kill_event):
+ def pull_tasks(self):
""" Pull tasks from the incoming tasks zmq pipe onto the internal
pending task queue

- Parameters:
- -----------
- kill_event : threading.Event
- Event to let the thread know when it is time to die.
"""
logger.info("starting")
poller = zmq.Poller()
@@ -319,7 +314,7 @@ def pull_tasks(self, kill_event):
last_interchange_contact = time.time()
task_recv_counter = 0

- while not kill_event.is_set():
+ while not self._stop_event.is_set():

# This loop will sit inside poller.poll until either a message
# arrives or one of these event times is reached. This code
@@ -367,7 +362,7 @@ def pull_tasks(self, kill_event):
logger.debug("Got heartbeat from interchange")
elif tasks == DRAINED_CODE:
logger.info("Got fully drained message from interchange - setting kill flag")
- kill_event.set()
+ self._stop_event.set()
else:
task_recv_counter += len(tasks)
logger.debug("Got executor tasks: {}, cumulative count of tasks: {}".format(
@@ -383,20 +378,14 @@ def pull_tasks(self, kill_event):
# Only check if no messages were received.
if time.time() >= last_interchange_contact + self.heartbeat_threshold:
logger.critical("Missing contact with interchange beyond heartbeat_threshold")
- kill_event.set()
+ self._stop_event.set()
logger.critical("Exiting")
break
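pull_tasks combines ZMQ polling with a heartbeat deadline: any message from the interchange refreshes last_interchange_contact, and if the deadline passes with no contact the loop sets the shared stop event itself, which shuts down every other loop. A standalone sketch of that pattern, assuming an inproc PULL socket and made-up constants rather than Parsl's real wiring:

import time
import threading
import zmq

stop_event = threading.Event()
heartbeat_threshold = 1.0   # assumption: seconds without contact before giving up
poll_period_ms = 100

ctx = zmq.Context()
task_socket = ctx.socket(zmq.PULL)
task_socket.bind("inproc://tasks")   # stand-in for the interchange connection

poller = zmq.Poller()
poller.register(task_socket, zmq.POLLIN)

last_contact = time.time()
while not stop_event.is_set():
    socks = dict(poller.poll(poll_period_ms))   # wait for a message or time out
    if task_socket in socks and socks[task_socket] == zmq.POLLIN:
        msg = task_socket.recv()
        last_contact = time.time()              # any message counts as contact
    elif time.time() >= last_contact + heartbeat_threshold:
        # No contact for too long: signal every loop in the process to stop.
        stop_event.set()

task_socket.close()
ctx.term()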

@wrap_with_logs
- def push_results(self, kill_event):
+ def push_results(self):
""" Listens on the pending_result_queue and sends out results via zmq

- Parameters:
- -----------
- kill_event : threading.Event
- Event to let the thread know when it is time to die.
"""

logger.debug("Starting result push thread")

push_poll_period = max(10, self.poll_period) / 1000 # push_poll_period must be at least 10 ms
@@ -406,7 +395,7 @@ def push_results(self, kill_event):
last_result_beat = time.time()
items = []

- while not kill_event.is_set():
+ while not self._stop_event.is_set():
try:
logger.debug("Starting pending_result_queue get")
r = self.task_scheduler.get_result(block=True, timeout=push_poll_period)
@@ -438,18 +427,11 @@ def push_results(self, kill_event):
logger.critical("Exiting")

@wrap_with_logs
- def worker_watchdog(self, kill_event: threading.Event):
- """Keeps workers alive.
-
- Parameters:
- -----------
- kill_event : threading.Event
- Event to let the thread know when it is time to die.
- """
-
+ def worker_watchdog(self):
+ """Keeps workers alive."""
logger.debug("Starting worker watchdog")

- while not kill_event.wait(self.heartbeat_period):
+ while not self._stop_event.wait(self.heartbeat_period):
for worker_id, p in self.procs.items():
if not p.is_alive():
logger.error("Worker {} has died".format(worker_id))
@@ -475,7 +457,7 @@ def worker_watchdog(self, kill_event: threading.Event):
logger.critical("Exiting")

@wrap_with_logs
- def handle_monitoring_messages(self, kill_event: threading.Event):
+ def handle_monitoring_messages(self):
"""Transfer messages from the managed monitoring queue to the result queue.

We separate the queues so that the result queue does not rely on a manager
@@ -489,7 +471,7 @@ def handle_monitoring_messages(self, kill_event: threading.Event):

poll_period_s = max(10, self.poll_period) / 1000 # Must be at least 10 ms

- while not kill_event.is_set():
+ while not self._stop_event.is_set():
try:
logger.debug("Starting monitor_queue.get()")
msg = self.monitoring_queue.get(block=True, timeout=poll_period_s)
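As the docstring above explains, the monitoring handler exists only to bridge two queues: worker processes write to a manager-backed queue, and this thread copies messages onto the local result path so the zmq pusher never blocks on the manager's server process. A reduced sketch of that bridge, with an illustrative sample message:

import multiprocessing
import queue
import threading

def bridge(monitoring_queue, pending_result_queue, stop_event):
    # Drain the manager-backed queue into the local queue until told to stop.
    while not stop_event.is_set():
        try:
            msg = monitoring_queue.get(block=True, timeout=0.01)
            pending_result_queue.put(msg)
        except queue.Empty:
            pass

if __name__ == "__main__":
    stop_event = threading.Event()
    mp_manager = multiprocessing.get_context("spawn").Manager()  # server process
    monitoring_queue = mp_manager.Queue()   # written to by worker processes
    pending_result_queue = queue.Queue()    # read by the zmq result pusher

    monitoring_queue.put({"event": "task_start"})   # sample message
    threading.Timer(0.2, stop_event.set).start()    # simulated shutdown
    bridge(monitoring_queue, pending_result_queue, stop_event)
    print(pending_result_queue.get_nowait())        # {'event': 'task_start'}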
@@ -516,18 +498,16 @@ def start(self):

logger.debug("Workers started")

- thr_task_puller = threading.Thread(target=self.pull_tasks,
- args=(self._kill_event,),
- name="Task-Puller")
- thr_result_pusher = threading.Thread(target=self.push_results,
- args=(self._kill_event,),
- name="Result-Pusher")
- thr_worker_watchdog = threading.Thread(target=self.worker_watchdog,
- args=(self._kill_event,),
- name="worker-watchdog")
- thr_monitoring_handler = threading.Thread(target=self.handle_monitoring_messages,
- args=(self._kill_event,),
- name="Monitoring-Handler")
+ thr_task_puller = threading.Thread(target=self.pull_tasks, name="Task-Puller")
+ thr_result_pusher = threading.Thread(
+ target=self.push_results, name="Result-Pusher"
+ )
+ thr_worker_watchdog = threading.Thread(
+ target=self.worker_watchdog, name="worker-watchdog"
+ )
+ thr_monitoring_handler = threading.Thread(
+ target=self.handle_monitoring_messages, name="Monitoring-Handler"
+ )

thr_task_puller.start()
thr_result_pusher.start()
Expand All @@ -537,7 +517,7 @@ def start(self):
logger.info("Manager threads started")

# This might need a multiprocessing event to signal back.
- self._kill_event.wait()
+ self._stop_event.wait()
logger.critical("Received kill event, terminating worker processes")

thr_task_puller.join()