Skip to content

Commit acd5ef9

Browse files
committed
BatchSpawnerBase: Add background_tasks, connect_to_job feature.
This adds the ability to start a "connect_to_job" background task on the Hub when a job starts; the task establishes connectivity to the actual single-user server. One example is "condor_ssh_to_job" for HTCondor batch systems. Additionally, background tasks are monitored in two ways: (1) at startup, where the background task is given some time to establish connectivity successfully, and (2) in poll() during job runtime, where the job is terminated if a background task has failed.
1 parent 129951a commit acd5ef9

File tree

1 file changed

+112
-0
lines changed

1 file changed

+112
-0
lines changed

batchspawner/batchspawner.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,12 @@ def _req_keepvars_default(self):
172172
"specification."
173173
).tag(config=True)
174174

175+
connect_to_job_cmd = Unicode('',
176+
help="Command to connect to running batch job and forward the port "
177+
"of the running notebook to the Hub. If empty, direct connectivity is assumed. "
178+
"Uses self.job_id as {job_id} and the self.port as {port}."
179+
).tag(config=True)
180+
175181
# Raw output of job submission command unless overridden
176182
job_id = Unicode()
177183

@@ -200,6 +206,18 @@ def cmd_formatted_for_batch(self):
200206
"""The command which is substituted inside of the batch script"""
201207
return ' '.join([self.batchspawner_singleuser_cmd] + self.cmd + self.get_args())
202208

209+
async def connect_to_job(self):
    """Make the single-user server's port reachable from the Hub machine.

    Formats ``exec_prefix`` and ``connect_to_job_cmd`` with the request
    subvars plus ``{job_id}`` (``self.job_id``) and ``{port}``
    (``self.port``), joins them with a space and passes the resulting
    command line to ``run_background_command``.

    If ``connect_to_job_cmd`` is empty, direct connectivity is assumed and
    nothing is run.
    """
    # Without this guard an unset command would execute the bare exec_prefix.
    if not self.connect_to_job_cmd:
        return
    subvars = self.get_req_subvars()
    subvars['job_id'] = self.job_id
    subvars['port'] = self.port
    cmd = ' '.join((format_template(self.exec_prefix, **subvars),
                    format_template(self.connect_to_job_cmd, **subvars)))
    await self.run_background_command(cmd)
220+
203221
async def run_command(self, cmd, input=None, env=None):
204222
proc = await asyncio.create_subprocess_shell(cmd, env=env,
205223
stdin=asyncio.subprocess.PIPE,
@@ -236,6 +254,46 @@ async def run_command(self, cmd, input=None, env=None):
236254
out = out.decode().strip()
237255
return out
238256

257+
# List of running background processes, e.g. used by connect_to_job.
258+
background_processes = []
259+
260+
async def _async_wait_process(self, sleep_time):
261+
"""Asynchronously sleeping process for delayed checks"""
262+
await asyncio.sleep(sleep_time)
263+
264+
async def run_background_command(self, cmd, startup_check_delay=1, input=None, env=None):
    """Start ``cmd`` in the background and verify that it survives startup.

    The command is launched through ``run_command`` as an asyncio task and
    raced against a timer of ``startup_check_delay`` seconds.

    Returns:
        The still-running task, after registering it in
        ``self.background_processes``, when the command outlives the timer.
        ``None`` when the command finished within the delay without raising
        (only an error is logged in that case).

    Raises:
        RuntimeError: if the command failed or was cancelled within the
            startup delay.
    """
    # asyncio.wait() needs real Task/Future objects: passing bare coroutines
    # is deprecated since Python 3.8 and rejected by Python 3.11+.
    background_task = asyncio.ensure_future(self.run_command(cmd, input, env))
    startup_timer = asyncio.ensure_future(self._async_wait_process(startup_check_delay))

    # Start up both the success check timer and the actual process.
    await asyncio.wait([background_task, startup_timer],
                       return_when=asyncio.FIRST_COMPLETED)

    # Decide on the command's own state instead of peeking at whichever task
    # happens to be first in the "done" set.
    if not background_task.done():
        # Rebind to an instance-owned list on first use so tasks are not
        # appended to the class-level list shared by all instances.
        if 'background_processes' not in vars(self):
            self.background_processes = []
        self.background_processes.append(background_task)
        return background_task

    self.log.error("Background command exited early: %s" % cmd)
    if not startup_timer.done():
        startup_timer.cancel()
        try:
            self.log.debug("Cancelling pending success check task...")
            await startup_timer
        except asyncio.CancelledError:
            self.log.debug("Cancel was successful.")

    # Retrieve the exception from the finished command, if there is one.
    try:
        await background_task
    except (Exception, asyncio.CancelledError) as exc:
        self.log.debug("Retrieving exception from failed background task...")
        raise RuntimeError('{} failed!'.format(cmd)) from exc
296+
239297
async def _get_batch_script(self, **subvars):
240298
"""Format batch script from vars"""
241299
# Could be overridden by subclasses, but mainly useful for testing
@@ -263,6 +321,27 @@ async def submit_batch_script(self):
263321
self.job_id = ''
264322
return self.job_id
265323

324+
def background_tasks_ok(self):
    """Report whether all registered background processes are still running.

    Walks ``self.background_processes`` and stops at the first entry that is
    done, whether it finished, raised or was cancelled.

    Returns:
        ``False`` as soon as one background process is found in state
        "done"; ``True`` if there are none or all are still running.
    """
    if self.background_processes:
        self.log.debug('Checking background processes...')
        for background_process in self.background_processes:
            if not background_process.done():
                self.log.debug('Found a not-yet-done background process...')
                continue
            self.log.debug('Found a background process in state "done"...')
            try:
                background_exception = background_process.exception()
            except asyncio.CancelledError:
                # exception() raises for a cancelled task; bind the name so
                # the check below cannot hit an unbound local.
                background_exception = None
                self.log.error('Background process was cancelled!')
            if background_exception:
                self.log.error('Background process exited with an exception:')
                self.log.error(background_exception)
            self.log.error('At least one background process exited!')
            return False
        self.log.debug('All background processes still running.')
    return True
344+
266345
# Override if your batch system needs something more elaborate to query the job status
267346
batch_query_cmd = Unicode('',
268347
help="Command to run to query job status. Formatted using req_xyz traits as {xyz} "
@@ -307,6 +386,29 @@ async def cancel_batch_job(self):
307386
cmd = ' '.join((format_template(self.exec_prefix, **subvars),
308387
format_template(self.batch_cancel_cmd, **subvars)))
309388
self.log.info('Cancelling job ' + self.job_id + ': ' + cmd)
389+
390+
if self.background_processes:
391+
self.log.debug('Job being cancelled, cancelling background processes...')
392+
for background_process in self.background_processes:
393+
if not background_process.cancelled():
394+
try:
395+
background_process.cancel()
396+
except:
397+
self.log.error('Encountered an exception cancelling background process...')
398+
self.log.debug('Cancelled background process, waiting for it to finish...')
399+
try:
400+
await asyncio.wait([background_process])
401+
except asyncio.CancelledError:
402+
self.log.error('Successfully cancelled background process.')
403+
pass
404+
except:
405+
self.log.error('Background process exited with another exception!')
406+
raise
407+
else:
408+
self.log.debug('Background process already cancelled...')
409+
self.background_processes.clear()
410+
self.log.debug('All background processes cancelled.')
411+
310412
await self.run_command(cmd)
311413

312414
def load_state(self, state):
@@ -354,6 +456,13 @@ async def poll(self):
354456
"""Poll the process"""
355457
status = await self.query_job_status()
356458
if status in (JobStatus.PENDING, JobStatus.RUNNING, JobStatus.UNKNOWN):
459+
if not self.background_tasks_ok():
460+
self.log.debug('Going to stop job, since background tasks have failed!')
461+
await self.stop(now=True)
462+
status = await self.query_job_status()
463+
if status not in (JobStatus.PENDING, JobStatus.RUNNING, JobStatus.UNKNOWN):
464+
self.clear_state()
465+
return 1
357466
return None
358467
else:
359468
self.clear_state()
@@ -413,6 +522,9 @@ async def start(self):
413522
self.job_id, self.ip, self.port)
414523
)
415524

525+
if self.connect_to_job_cmd:
526+
await self.connect_to_job()
527+
416528
return self.ip, self.port
417529

418530
async def stop(self, now=False):

0 commit comments

Comments
 (0)