From acc09a5d5cc7f21729a3f2a25a0184d1b0ca04c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=B1=9F=E5=AE=B6=E7=91=8B?= <36886416+JiangJiaWei1103@users.noreply.github.com> Date: Sat, 22 Feb 2025 13:10:12 +0800 Subject: [PATCH] Slurm agent fn task (#3150) * Add slurm plugin blank components Signed-off-by: jiangjiawei1103 * feat: Add naive slurm agent create and get with rest api Signed-off-by: jiangjiawei1103 * Use asyncssh instead of REST API Signed-off-by: jiangjiawei1103 * Test ssh communication and run sbatch Signed-off-by: JiaWei Jiang * Add delete method and support slurm job state Signed-off-by: JiaWei Jiang * feat: Submit and run SlurmTask on a remote Slurm cluster Successfully submit and run the user-defined task as a normal python function on a remote Slurm cluster. 1. Inherit from PythonFunctionTask instead of PythonTask 2. Transfer the task module through sftp 3. Interact with amazon s3 bucket on both localhost and Slurm cluster Signed-off-by: JiaWei Jiang * refactor: Remove redundant task_module transfer Specifying `--raw-output-data-prefix` option handles task_module download. Signed-off-by: JiaWei Jiang * refactor: Remove redundant env var Signed-off-by: JiaWei Jiang * docs: Add env setup guide for local test Signed-off-by: JiaWei Jiang * docs: Add links and figures Signed-off-by: JiaWei Jiang * docs: Fix commit sha Signed-off-by: JiaWei Jiang * docs: Fix commit sha for demo guide Signed-off-by: JiaWei Jiang * docs: Fix links Signed-off-by: JiaWei Jiang * feat: Support SSH config in task config Add `ssh_conf` filed to let users specify connection secret Note that reconnection is done in both `get` and `delete`. This is just a temporary workaround. 
Signed-off-by: JiaWei Jiang * docs: Include ssh config in demo example Signed-off-by: JiaWei Jiang * fix: Retain user-specified file format info Signed-off-by: JiaWei Jiang * fix: Set sdt format based on user-specified file_format Signed-off-by: JiaWei Jiang * Remove redundant modification Signed-off-by: JiaWei Jiang * test: Test file_format attribute alignment in dc.sd Signed-off-by: JiaWei Jiang * refactor: Reduce ssh_conf option to slurm_host only For data scientists and MLEs developing flyte wf with Slurm agent, they don't actually need to know ssh connection details. We assume they only need to specify which Slurm cluster to use by hostname. Signed-off-by: JiaWei Jiang * feat: Support Slurm agent with ShellTask 1. Write user-defined batch script to a tmp file 2. Transfer the batch script through sftp 3. Construct sbatch command to run on Slurm cluster Signed-off-by: JiaWei Jiang * feat: Simplify Slurm job submission logic 1. Remove SFTP for batch script transfer * Assume Slurm batch script is present on Slurm cluster 2. Support directly specifying a remote batch script path Signed-off-by: JiaWei Jiang * Added script args to agent and task Signed-off-by: pryce-turner * Add asyncssh to dependencies Signed-off-by: JiaWei Jiang * docs: Update setup and demo for a basic use case Signed-off-by: JiaWei Jiang * docs: Update basic arch figure path Signed-off-by: JiaWei Jiang * docs: Fix typo and hyperlink Signed-off-by: JiaWei Jiang * fix: A tmp workaround to test agent locally without container_image Signed-off-by: JiaWei Jiang * feat: Support user-defined batch script content with SlurmShellTask `SlurmTask` and `SlurmShellTask` now share the same agent. Signed-off-by: JiaWei Jiang * feat: Fall back to PythonTask for naive use cases 1. Inherited from `PythonTask` for cases in which the batch script is already on the Slurm cluster 2. 
Use a dummy `Interface` as a tmp workaround Signed-off-by: JiaWei Jiang * refactor: Define Slurm as a base task config and extend for remote script Signed-off-by: JiaWei Jiang * feat: Support PythonFunctionTask and reorganize agent structure 1. Add back `PythonFunctionTask` to support running user-defined functions on Slurm 2. Categorize task types into `script/` and `function/` Signed-off-by: JiaWei Jiang * Use poetry virtual env to avoid contamination Signed-off-by: JiangJiaWei1103 * docs: Complete local test env setup process Signed-off-by: JiangJiaWei1103 * docs: Add use cases ranging from basic to advanced Signed-off-by: JiangJiaWei1103 * feat: Add a script option for the Slurm function task Signed-off-by: JiangJiaWei1103 * fix: Avoid attaching async resource to different event loops Signed-off-by: JiangJiaWei1103 * use await self._connect(slurm_host) in slurm agent Signed-off-by: Future-Outlier * change Signed-off-by: Future-Outlier * print more info Signed-off-by: Future-Outlier * use logger Signed-off-by: Future-Outlier * print more infor Signed-off-by: Future-Outlier * print Signed-off-by: Future-Outlier * Use sbatch for running Slurm function task Signed-off-by: JiangJiaWei1103 * update Signed-off-by: Future-Outlier * push Signed-off-by: Future-Outlier * feat: Show stdout and stderr msg of the Slurm cluster Signed-off-by: JiangJiaWei1103 * feat: Show stdout and stderr msg of the Slurm cluster for SlurmFunctionTask Signed-off-by: JiangJiaWei1103 * feat: Make an SSH connetion based on client config file or ssh_config 1. Make SSH `host` and `username` required fields 2. Support SSH connection based on the default OpenSSH client config file `~/.ssh/config` 3. 
Support SSH connection via public key auth either by user-specified `client_keys` or the secret for key `FLYTE_SLURM_PRIVATE_KEY` Signed-off-by: JiangJiaWei1103 * Clarify SSH connection logic Signed-off-by: JiangJiaWei1103 * feat: Interpolate the script with dynamic input values Signed-off-by: JiangJiaWei1103 * feat: Interpolate the script with dynamic output values Support passing files across multiple `SlurmShellTask` Signed-off-by: JiangJiaWei1103 * add assertion Signed-off-by: Future-Outlier * update Signed-off-by: Future-Outlier * update Signed-off-by: Future-Outlier * Fix Script agent bug Signed-off-by: Future-Outlier * agent service for shell task Signed-off-by: Future-Outlier * Remove remote path to avoid race condition Signed-off-by: Future-Outlier * Revert agent server change Signed-off-by: Future-Outlier * use key val to run ssh config Signed-off-by: Future-Outlier * update Signed-off-by: Future-Outlier * use _get_or_create_ssh_connection Signed-off-by: Future-Outlier * update Signed-off-by: Future-Outlier * use SlurmCluster and hash Signed-off-by: Future-Outlier * updagte Signed-off-by: Future-Outlier * update Signed-off-by: Future-Outlier * update Signed-off-by: Future-Outlier * refactor: Simplify validation process and clean up legacy code 1. Ensure `"host"` must be provided in `__post_init__` 2. Explicitly set `known_hosts` to `None` 3. Make `username` optional 4. Remove legacy code snippets 5. 
Make docstring clear Signed-off-by: JiangJiaWei1103 * Add Slurm agent function task Signed-off-by: JiangJiaWei1103 * Revert ShellTask behavior Signed-off-by: JiangJiaWei1103 * Remove fix for SlurmShellTask Signed-off-by: JiangJiaWei1103 * Remove blank line Signed-off-by: JiangJiaWei1103 * fix doc string and remove logs Signed-off-by: Future-Outlier * build plugins Signed-off-by: Future-Outlier * merge master Signed-off-by: Future-Outlier * fix-sage-maker-test Signed-off-by: Future-Outlier * add test_slurm_fn_task Signed-off-by: Future-Outlier * fix Signed-off-by: Future-Outlier * update flytebot Signed-off-by: Future-Outlier * add know host = None Signed-off-by: Future-Outlier --------- Signed-off-by: jiangjiawei1103 Signed-off-by: JiaWei Jiang Signed-off-by: pryce-turner Signed-off-by: JiangJiaWei1103 Signed-off-by: Future-Outlier Co-authored-by: pryce-turner Co-authored-by: Future-Outlier --- .github/workflows/pythonbuild.yml | 1 + flytekit/extend/backend/utils.py | 4 +- .../tests/test_boto3_mixin.py | 6 +- plugins/flytekit-slurm/README.md | 5 + .../flytekitplugins/slurm/__init__.py | 2 + .../flytekitplugins/slurm/function/agent.py | 190 ++++++++++++++++++ .../flytekitplugins/slurm/function/task.py | 83 ++++++++ .../flytekitplugins/slurm/ssh_utils.py | 131 ++++++++++++ plugins/flytekit-slurm/setup.py | 40 ++++ plugins/flytekit-slurm/tests/__init__.py | 0 .../tests/test_slurm_fn_task.py | 55 +++++ 11 files changed, 512 insertions(+), 5 deletions(-) create mode 100644 plugins/flytekit-slurm/README.md create mode 100644 plugins/flytekit-slurm/flytekitplugins/slurm/__init__.py create mode 100644 plugins/flytekit-slurm/flytekitplugins/slurm/function/agent.py create mode 100644 plugins/flytekit-slurm/flytekitplugins/slurm/function/task.py create mode 100644 plugins/flytekit-slurm/flytekitplugins/slurm/ssh_utils.py create mode 100644 plugins/flytekit-slurm/setup.py create mode 100644 plugins/flytekit-slurm/tests/__init__.py create mode 100644 
plugins/flytekit-slurm/tests/test_slurm_fn_task.py diff --git a/.github/workflows/pythonbuild.yml b/.github/workflows/pythonbuild.yml index 912a0a01c6..bed83af5c5 100644 --- a/.github/workflows/pythonbuild.yml +++ b/.github/workflows/pythonbuild.yml @@ -337,6 +337,7 @@ jobs: - flytekit-papermill - flytekit-polars - flytekit-ray + - flytekit-slurm - flytekit-snowflake - flytekit-spark - flytekit-sqlalchemy diff --git a/flytekit/extend/backend/utils.py b/flytekit/extend/backend/utils.py index 4dcdf3174a..8aa7952134 100644 --- a/flytekit/extend/backend/utils.py +++ b/flytekit/extend/backend/utils.py @@ -20,9 +20,9 @@ def convert_to_flyte_phase(state: str) -> TaskExecution.Phase: Convert the state from the agent to the phase in flyte. """ state = state.lower() - if state in ["failed", "timeout", "timedout", "canceled", "skipped", "internal_error"]: + if state in ["failed", "timeout", "timedout", "canceled", "cancelled", "skipped", "internal_error"]: return TaskExecution.FAILED - elif state in ["done", "succeeded", "success"]: + elif state in ["done", "succeeded", "success", "completed"]: return TaskExecution.SUCCEEDED elif state in ["running", "terminating"]: return TaskExecution.RUNNING diff --git a/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py b/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py index 304ae49a01..1e12ffb227 100644 --- a/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py +++ b/plugins/flytekit-aws-sagemaker/tests/test_boto3_mixin.py @@ -5,7 +5,7 @@ from flytekitplugins.awssagemaker_inference import triton_image_uri from flytekitplugins.awssagemaker_inference.boto3_mixin import ( Boto3AgentMixin, - update_dict_fn, + format_dict, ) from flytekit import FlyteContext, StructuredDataset @@ -50,7 +50,7 @@ def test_inputs(): }, ) - result = update_dict_fn( + result = format_dict( service="s3", original_dict=original_dict, update_dict={"inputs": literal_map_string_repr(inputs)}, @@ -75,7 +75,7 @@ def test_container(): original_dict = 
{"a": "{images.primary_container_image}"} images = {"primary_container_image": "cr.flyte.org/flyteorg/flytekit:py3.11-1.10.3"} - result = update_dict_fn( + result = format_dict( service="sagemaker", original_dict=original_dict, update_dict={"images": images} ) diff --git a/plugins/flytekit-slurm/README.md b/plugins/flytekit-slurm/README.md new file mode 100644 index 0000000000..af6596cf28 --- /dev/null +++ b/plugins/flytekit-slurm/README.md @@ -0,0 +1,5 @@ +# Flytekit Slurm Plugin + +The Slurm agent is designed to integrate Flyte workflows with Slurm-managed high-performance computing (HPC) clusters, enabling users to leverage Slurm's capability of compute resource allocation, scheduling, and monitoring. + +This [guide](https://github.com/JiangJiaWei1103/flytekit/blob/slurm-agent-dev/plugins/flytekit-slurm/demo.md) provides a concise overview of the design philosophy behind the Slurm agent and explains how to set up a local environment for testing the agent. diff --git a/plugins/flytekit-slurm/flytekitplugins/slurm/__init__.py b/plugins/flytekit-slurm/flytekitplugins/slurm/__init__.py new file mode 100644 index 0000000000..c841712de6 --- /dev/null +++ b/plugins/flytekit-slurm/flytekitplugins/slurm/__init__.py @@ -0,0 +1,2 @@ +from .function.agent import SlurmFunctionAgent +from .function.task import SlurmFunction, SlurmFunctionTask diff --git a/plugins/flytekit-slurm/flytekitplugins/slurm/function/agent.py b/plugins/flytekit-slurm/flytekitplugins/slurm/function/agent.py new file mode 100644 index 0000000000..b1799a7ae3 --- /dev/null +++ b/plugins/flytekit-slurm/flytekitplugins/slurm/function/agent.py @@ -0,0 +1,190 @@ +import tempfile +import uuid +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +from asyncssh import SSHClientConnection + +from flytekit import logger +from flytekit.extend.backend.base_agent import AgentRegistry, AsyncAgentBase, Resource, ResourceMeta +from flytekit.extend.backend.utils import 
@dataclass(frozen=True)
class SlurmCluster:
    """Hashable identity of a Slurm cluster endpoint.

    Instances are intended to be used as dictionary keys (for example to look up a
    cached SSH connection per cluster), so the dataclass is frozen: a key that could
    be mutated after insertion would silently break dictionary lookups. Equality and
    hashing are both derived from ``(host, username)``.

    Attributes:
        host (str): Hostname or address of the Slurm cluster.
        username (Optional[str]): Login user on the cluster, or None when unspecified.
    """

    host: str
    username: Optional[str] = None
async def get(self, resource_meta: SlurmJobMetadata, **kwargs) -> Resource:
    """Poll a Slurm job and translate its state into a Flyte phase.

    Runs ``scontrol show job <job_id>`` over the SSH connection for the cluster
    described by ``resource_meta.ssh_config`` and scans the whitespace-separated
    ``Key=Value`` fields of its output for ``JobState`` and ``StdOut``.

    Args:
        resource_meta (SlurmJobMetadata): Job id and SSH configuration of the submitted job.
        **kwargs: Accepted for interface compatibility; not used here.

    Returns:
        Resource: The Flyte phase derived from the job state, with the content of the
        job's stdout file as message ("No stdout available" when it cannot be read).
    """
    # Function-scope import keeps this block self-contained.
    import shlex

    conn = await self._get_or_create_ssh_connection(resource_meta.ssh_config)
    job_res = await conn.run(f"scontrol show job {shlex.quote(str(resource_meta.job_id))}", check=True)

    # Defaults used when the corresponding field is absent from the scontrol output
    job_state = "running"
    msg = "No stdout available"
    for field in job_res.stdout.split():
        # partition keeps everything after the first "=", so values containing "=" survive
        key, sep, value = field.partition("=")
        if not sep:
            continue
        if key == "JobState":
            job_state = value.lower()
        elif key == "StdOut":
            # The log file may be unreadable (e.g. not written yet); that must not
            # fail the status poll, so the exit status is inspected instead of raising.
            out_res = await conn.run(f"cat {shlex.quote(value)}", check=False)
            if out_res.exit_status == 0:
                msg = out_res.stdout

    return Resource(phase=convert_to_flyte_phase(job_state), message=msg)
+ """ + host = ssh_config.get("host") + username = ssh_config.get("username") + + ssh_cluster_config = SlurmCluster(host=host, username=username) + if self.ssh_config_to_ssh_conn.get(ssh_cluster_config) is None: + logger.info("ssh connection key not found, creating new connection") + conn = await ssh_connect(ssh_config=ssh_config) + self.ssh_config_to_ssh_conn[ssh_cluster_config] = conn + else: + conn = self.ssh_config_to_ssh_conn[ssh_cluster_config] + try: + await conn.run("echo [TEST] SSH connection", check=True) + logger.info("re-using new connection") + except Exception as e: + logger.info(f"Re-establishing SSH connection due to error: {e}") + conn = await ssh_connect(ssh_config=ssh_config) + self.ssh_config_to_ssh_conn[ssh_cluster_config] = conn + + return conn + + +def _get_sbatch_cmd_and_script( + sbatch_conf: Dict[str, str], + entrypoint: str, + script: Optional[str] = None, + batch_script_path: str = "/tmp/task.slurm", +) -> Tuple[str, str]: + """Construct the Slurm sbatch command and the batch script content. + + Flyte entrypoint, pyflyte-execute, is run within a bash shell process. + + Args: + sbatch_conf (Dict[str, str]): Options of sbatch command. + entrypoint (str): Flyte entrypoint. + script (Optional[str], optional): User-defined script where "{task.fn}" serves as a placeholder for the + task function execution. Users should insert "{task.fn}" at the desired + execution point within the script. If the script is not provided, the task + function will be executed directly. Defaults to None. + batch_script_path (str, optional): Absolute path of the batch script on Slurm cluster. + Defaults to "/tmp/task.slurm". 
+ + Returns: + Tuple[str, str]: A tuple containing: + - cmd: Slurm sbatch command + - script: The batch script content + """ + # Setup sbatch options + cmd = ["sbatch"] + for opt, val in sbatch_conf.items(): + cmd.extend([f"--{opt}", str(val)]) + + # Assign the batch script to run + cmd.append(batch_script_path) + + if script is None: + script = f"""#!/bin/bash -i + {entrypoint} + """ + else: + script = script.replace("{task.fn}", entrypoint) + + cmd = " ".join(cmd) + + return cmd, script + + +AgentRegistry.register(SlurmFunctionAgent()) diff --git a/plugins/flytekit-slurm/flytekitplugins/slurm/function/task.py b/plugins/flytekit-slurm/flytekitplugins/slurm/function/task.py new file mode 100644 index 0000000000..9c9220f3e3 --- /dev/null +++ b/plugins/flytekit-slurm/flytekitplugins/slurm/function/task.py @@ -0,0 +1,83 @@ +""" +Slurm task. +""" + +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +from flytekit import FlyteContextManager, PythonFunctionTask +from flytekit.configuration import SerializationSettings +from flytekit.extend import TaskPlugins +from flytekit.extend.backend.base_agent import AsyncAgentExecutorMixin +from flytekit.image_spec import ImageSpec + + +@dataclass +class SlurmFunction(object): + """Configure Slurm settings. Note that we focus on sbatch command now. + + Args: + ssh_config: Options of SSH client connection. For available options, please refer to + + sbatch_conf: Options of sbatch command. If not provided, defaults to an empty dict. + script: User-defined script where "{task.fn}" serves as a placeholder for the + task function execution. Users should insert "{task.fn}" at the desired + execution point within the script. If the script is not provided, the task + function will be executed directly. + + Attributes: + ssh_config (Dict[str, Union[str, List[str], Tuple[str, ...]]]): SSH client configuration options. 
class SlurmFunctionTask(AsyncAgentExecutorMixin, PythonFunctionTask[SlurmFunction]):
    """Python function task whose task type is "slurm_fn".

    The task serializes its SSH, sbatch and script settings into the task template's
    custom field, and chooses between the agent executor and a plain Python call at
    execution time depending on whether the current execution is local.
    """

    _TASK_TYPE = "slurm_fn"

    def __init__(
        self,
        task_config: SlurmFunction,
        task_function: Callable,
        container_image: Optional[Union[str, ImageSpec]] = None,
        **kwargs,
    ):
        """Forward the configuration to the base task with the Slurm task type.

        Args:
            task_config (SlurmFunction): Slurm settings for this task.
            task_function (Callable): The user function to run.
            container_image (Optional[Union[str, ImageSpec]]): Image passed through to the base task.
            **kwargs: Remaining options passed through to the base task.
        """
        super().__init__(
            task_config=task_config,
            task_type=self._TASK_TYPE,
            task_function=task_function,
            container_image=container_image,
            **kwargs,
        )

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        """Return the task config fields stored in the task's custom section.

        Args:
            settings (SerializationSettings): Serialization settings; not consulted here.

        Returns:
            Dict[str, Any]: The "ssh_config", "sbatch_conf" and "script" values of the task config.
        """
        config = self.task_config
        return dict(
            ssh_config=config.ssh_config,
            sbatch_conf=config.sbatch_conf,
            script=config.script,
        )

    def execute(self, **kwargs) -> Any:
        """Run through the agent for local executions, otherwise call the function directly."""
        ctx = FlyteContextManager.current_context()
        is_local = ctx.execution_state and ctx.execution_state.is_local_execution()
        if not is_local:
            # Execute the task with a direct python function call
            return PythonFunctionTask.execute(self, **kwargs)
        # Mimic the propeller's behavior in local agent test
        return AsyncAgentExecutorMixin.execute(self, **kwargs)
async def ssh_connect(ssh_config: Dict[str, Any]) -> SSHClientConnection:
    """Make an SSH client connection.

    Two attempts are made:

    1. Connect with the validated ``ssh_config`` as given (OpenSSH client config
       and/or the user-provided ``client_keys``).
    2. If that fails, connect again with ``client_keys`` set to the private key stored
       in the FLYTE_SLURM_PRIVATE_KEY secret (when set) followed by the user-provided keys.

    Args:
        ssh_config (Dict[str, Any]): Options of SSH client connection defined in SSHConfig.

    Returns:
        SSHClientConnection: An SSH client connection object.

    Raises:
        ValueError: If the first attempt fails and both the FLYTE_SLURM_PRIVATE_KEY secret
            and ssh_config['client_keys'] are missing.
        Exception: Whatever ``asyncssh.connect`` raised when the second attempt fails too.
    """
    # Validate ssh_config
    ssh_config = SSHConfig.from_dict(ssh_config).to_dict()
    # This is required to avoid the error "asyncssh.misc.HostKeyNotVerifiable" when connecting to a new host.
    ssh_config["known_hosts"] = None

    # Make the first SSH connection using either OpenSSH client config files or
    # a user-defined private key. If using OpenSSH config, it will attempt to
    # load settings from ~/.ssh/config.
    try:
        return await asyncssh.connect(**ssh_config)
    except Exception as e:
        # Deliberately broad: any failure here falls through to the secret-based attempt
        logger.info(
            "Failed to make an SSH connection using the default OpenSSH client config (~/.ssh/config) or "
            f"the provided private keys. Error details:\n{e}"
        )

    try:
        default_client_key = get_agent_secret(secret_key=SLURM_PRIVATE_KEY)
    except ValueError:
        logger.info("The secret for key FLYTE_SLURM_PRIVATE_KEY is not set.")
        default_client_key = None

    # Any empty value ((), [], "" or None) counts as "no user-provided keys"
    user_client_keys = ssh_config.get("client_keys")
    if default_client_key is None and not user_client_keys:
        raise ValueError(
            "Both the secret for key FLYTE_SLURM_PRIVATE_KEY and ssh_config['client_keys'] are missing. "
            "At least one must be set."
        )

    client_keys = []
    if default_client_key is not None:
        # Write the private key to a local path, readable and writable by the owner only.
        # The mode given to os.open applies only when the file is created, hence the
        # extra chmod for a pre-existing file.
        private_key_path = os.path.abspath("./slurm_private_key")
        fd = os.open(private_key_path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w") as f:
            f.write(default_client_key)
        os.chmod(private_key_path, 0o600)
        client_keys.append(private_key_path)

    if user_client_keys:
        client_keys.extend([user_client_keys] if isinstance(user_client_keys, str) else user_client_keys)

    ssh_config["client_keys"] = client_keys
    logger.info(f"Updated SSH config: {ssh_config}")
    try:
        return await asyncssh.connect(**ssh_config)
    except Exception as e:
        logger.error(
            "Failed to make an SSH connection using the provided private keys. Please verify your setup. "
            f"Error details:\n{e}"
        )
        # Propagate the failure so the caller decides how to handle it, rather than
        # terminating the whole process from library code.
        raise
"Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + ], + entry_points={"flytekit.plugins": [f"{PLUGIN_NAME}=flytekitplugins.{PLUGIN_NAME}"]}, +) diff --git a/plugins/flytekit-slurm/tests/__init__.py b/plugins/flytekit-slurm/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/plugins/flytekit-slurm/tests/test_slurm_fn_task.py b/plugins/flytekit-slurm/tests/test_slurm_fn_task.py new file mode 100644 index 0000000000..64539c7417 --- /dev/null +++ b/plugins/flytekit-slurm/tests/test_slurm_fn_task.py @@ -0,0 +1,55 @@ +import os.path +from unittest import mock +from flytekit.core import context_manager +import flytekit +from flytekit import StructuredDataset, StructuredDatasetTransformerEngine, task, ImageSpec +from flytekit.configuration import Image, ImageConfig, SerializationSettings, FastSerializationSettings, DefaultImages +from flytekit.core.context_manager import ExecutionParameters, FlyteContextManager, ExecutionState +from flytekitplugins.slurm import SlurmFunction + + +def test_slurm_task(): + script_file = """#!/bin/bash -i + + echo Run function with sbatch... 
+ + # Run the user-defined task function + {task.fn} + """ + + @task( + # container_image=image, + task_config=SlurmFunction( + ssh_config={ + "host": "your-slurm-host", + "username": "ubuntu", + }, + sbatch_conf={ + "partition": "debug", + "job-name": "tiny-slurm", + "output": "/home/ubuntu/fn_task.log" + }, + script=script_file + ) + ) + def plus_one(x: int) -> int: + return x + 1 + + assert plus_one.task_config is not None + assert plus_one.task_config.ssh_config == {"host": "your-slurm-host", "username": "ubuntu"} + assert plus_one.task_config.sbatch_conf == {"partition": "debug", "job-name": "tiny-slurm", "output": "/home/ubuntu/fn_task.log"} + assert plus_one.task_config.script == script_file + + default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={"FOO": "baz"}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + + retrieved_settings = plus_one.get_custom(settings) + assert retrieved_settings["ssh_config"] == {"host": "your-slurm-host", "username": "ubuntu"} + assert retrieved_settings["sbatch_conf"] == {"partition": "debug", "job-name": "tiny-slurm", "output": "/home/ubuntu/fn_task.log"} + assert retrieved_settings["script"] == script_file