diff --git a/batchspawner/batchspawner.py b/batchspawner/batchspawner.py index 318c7db7..93b1bf27 100644 --- a/batchspawner/batchspawner.py +++ b/batchspawner/batchspawner.py @@ -32,6 +32,7 @@ from jupyterhub.spawner import Spawner from traitlets import Integer, Unicode, Float, Dict, default +from jupyterhub.utils import random_port from jupyterhub.spawner import set_user_setuid @@ -186,6 +187,20 @@ def _req_keepvars_default(self): "specification.", ).tag(config=True) + connect_to_job_cmd = Unicode('', + help="Command to connect to running batch job and forward the port " + "of the running notebook to the Hub. If empty, direct connectivity is assumed. " + "Uses self.job_id as {job_id}, self.port as {port} and self.ip as {host}. " + "If {rport} is used in this string, it is set to self.port, " + "and a new random self.port is chosen locally and used as {port}. " + "This is useful e.g. for SSH port forwarding." + ).tag(config=True) + + rport = Integer(0, + help="Remote port of notebook, will be set if it differs from the self.port. " + "This is set by connect_to_job() if needed." + ) + # Raw output of job submission command unless overridden job_id = Unicode() @@ -215,6 +230,26 @@ def cmd_formatted_for_batch(self): """The command which is substituted inside of the batch script""" return " ".join([self.batchspawner_singleuser_cmd] + self.cmd + self.get_args()) + async def connect_to_job(self): + """This command ensures the port of the singleuser server is reachable from the + Batchspawner machine. Only called if connect_to_job_cmd is set. + If the template string connect_to_job_cmd contains {rport}, + a new random self.port is chosen locally (useful e.g. for SSH port forwarding).
+ """ + subvars = self.get_req_subvars() + subvars['host'] = self.ip + subvars['job_id'] = self.job_id + if '{rport}' in self.connect_to_job_cmd: + self.rport = self.port + self.port = random_port() + subvars['rport'] = self.rport + subvars['port'] = self.port + else: + subvars['port'] = self.port + cmd = ' '.join((format_template(self.exec_prefix, **subvars), + format_template(self.connect_to_job_cmd, **subvars))) + await self.run_background_command(cmd) + async def run_command(self, cmd, input=None, env=None): proc = await asyncio.create_subprocess_shell( cmd, @@ -268,6 +303,46 @@ async def run_command(self, cmd, input=None, env=None): out = out.decode().strip() return out + # List of running background processes, e.g. used by connect_to_job. + background_processes = [] + + async def _async_wait_process(self, sleep_time): + """Asynchronously sleeping process for delayed checks""" + await asyncio.sleep(sleep_time) + + async def run_background_command(self, cmd, startup_check_delay=1, input=None, env=None): + """Runs the given background command, adds it to background_processes, + and checks if the command is still running after startup_check_delay.""" + background_process = asyncio.ensure_future(self.run_command(cmd, input, env)) + success_check_delay = asyncio.ensure_future(self._async_wait_process(startup_check_delay)) + + # Start up both the success check process and the actual process. + done, pending = await asyncio.wait([background_process, success_check_delay], return_when=asyncio.FIRST_COMPLETED) + + # If the success check process is the one which exited first, all is good, else fail. 
+ if success_check_delay in done: + background_task = list(pending)[0] + self.background_processes.append(background_task) + return background_task + else: + self.log.error("Background command exited early: %s" % cmd) + gather_pending = asyncio.gather(*pending) + gather_pending.cancel() + try: + self.log.debug("Cancelling pending success check task...") + await gather_pending + except asyncio.CancelledError: + self.log.debug("Cancel was successful.") + pass + + # Retrieve exception from "done" process. + try: + gather_done = asyncio.gather(*done) + await gather_done + except: + self.log.debug("Retrieving exception from failed background task...") + raise RuntimeError('{} failed!'.format(cmd)) + async def _get_batch_script(self, **subvars): """Format batch script from vars""" # Could be overridden by subclasses, but mainly useful for testing @@ -299,6 +374,27 @@ async def submit_batch_script(self): self.job_id = "" return self.job_id + def background_tasks_ok(self): + # Check background processes. 
+ if self.background_processes: + self.log.debug('Checking background processes...') + for background_process in self.background_processes: + if background_process.done(): + self.log.debug('Found a background process in state "done"...') + try: + background_exception = background_process.exception() + except asyncio.CancelledError: + self.log.error('Background process was cancelled!') + if background_exception: + self.log.error('Background process exited with an exception:') + self.log.error(background_exception) + self.log.error('At least one background process exited!') + return False + else: + self.log.debug('Found a not-yet-done background process...') + self.log.debug('All background processes still running.') + return True + # Override if your batch system needs something more elaborate to query the job status batch_query_cmd = Unicode( "", @@ -353,6 +449,29 @@ async def cancel_batch_job(self): ) ) self.log.info("Cancelling job " + self.job_id + ": " + cmd) + + if self.background_processes: + self.log.debug('Job being cancelled, cancelling background processes...') + for background_process in self.background_processes: + if not background_process.cancelled(): + try: + background_process.cancel() + except: + self.log.error('Encountered an exception cancelling background process...') + self.log.debug('Cancelled background process, waiting for it to finish...') + try: + await asyncio.wait([background_process]) + except asyncio.CancelledError: + self.log.error('Successfully cancelled background process.') + pass + except: + self.log.error('Background process exited with another exception!') + raise + else: + self.log.debug('Background process already cancelled...') + self.background_processes.clear() + self.log.debug('All background processes cancelled.') + await self.run_command(cmd) def load_state(self, state): @@ -400,6 +519,13 @@ async def poll(self): """Poll the process""" status = await self.query_job_status() if status in (JobStatus.PENDING, 
JobStatus.RUNNING, JobStatus.UNKNOWN): + if not self.background_tasks_ok(): + self.log.debug('Going to stop job, since background tasks have failed!') + await self.stop(now=True) + status = await self.query_job_status() + if status not in (JobStatus.PENDING, JobStatus.RUNNING, JobStatus.UNKNOWN): + self.clear_state() + return 1 return None else: self.clear_state() @@ -459,6 +585,14 @@ async def start(self): if hasattr(self, "mock_port"): self.port = self.mock_port + if self.connect_to_job_cmd: + await self.connect_to_job() + + # Port and ip can be changed in connect_to_job, push out to jupyterhub. + if self.server: + self.server.port = self.port + self.server.ip = self.ip + self.db.commit() self.log.info( "Notebook server job {0} started at {1}:{2}".format( @@ -887,6 +1021,7 @@ class CondorSpawner(UserEnvMixin, BatchSpawnerRegexStates): 'condor_q {job_id} -format "%s, " JobStatus -format "%s" RemoteHost -format "\n" True' ).tag(config=True) batch_cancel_cmd = Unicode("condor_rm {job_id}").tag(config=True) + connect_to_job_cmd = Unicode("condor_ssh_to_job -ssh \"ssh -L {port}:localhost:{rport} -oExitOnForwardFailure=yes\" {job_id}").tag(config=True) # job status: 1 = pending, 2 = running state_pending_re = Unicode(r"^1,").tag(config=True) state_running_re = Unicode(r"^2,").tag(config=True) @@ -909,6 +1044,12 @@ def cmd_formatted_for_batch(self): .replace("'", "''") ) + def state_gethost(self): + """Returns localhost if connect_to_job is used, as this forwards the singleuser server port from the spawned job""" + if self.connect_to_job_cmd: + return "localhost" + else: + return super(CondorSpawner,self).state_gethost() class LsfSpawner(BatchSpawnerBase): """A Spawner that uses IBM's Platform Load Sharing Facility (LSF) to launch notebooks.""" diff --git a/batchspawner/tests/test_spawners.py b/batchspawner/tests/test_spawners.py index 2416680f..8cb970fd 100644 --- a/batchspawner/tests/test_spawners.py +++ b/batchspawner/tests/test_spawners.py @@ -560,6 +560,7 @@ def 
test_condor(db, io_loop): "req_nprocs": "5", "req_memory": "5678", "req_options": "some_option_asdf", + "connect_to_job_cmd": "", } batch_script_re_list = [ re.compile(r"exec batchspawner-singleuser singleuser_command"),