diff --git a/flytekit/core/constants.py b/flytekit/core/constants.py index a80ed0f9e4..2ed491ce60 100644 --- a/flytekit/core/constants.py +++ b/flytekit/core/constants.py @@ -11,6 +11,8 @@ START_NODE_ID = "start-node" END_NODE_ID = "end-node" +DEFAULT_FAILURE_NODE_ID = "efn" + # If set this environment variable overrides the default container image and the default base image in ImageSpec. FLYTE_INTERNAL_IMAGE_ENV_VAR = "FLYTE_INTERNAL_IMAGE" diff --git a/flytekit/core/node_creation.py b/flytekit/core/node_creation.py index 791480435f..3699078bf4 100644 --- a/flytekit/core/node_creation.py +++ b/flytekit/core/node_creation.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect from typing import TYPE_CHECKING, Union from flytekit.core.base_task import PythonTask @@ -10,6 +11,7 @@ from flytekit.core.workflow import WorkflowBase from flytekit.exceptions import user as _user_exceptions from flytekit.loggers import logger +from flytekit.utils.asyn import run_sync if TYPE_CHECKING: from flytekit.remote.remote_callable import RemoteEntity @@ -77,7 +79,10 @@ def create_node( # When compiling, calling the entity will create a node. ctx = FlyteContext.current_context() if ctx.compilation_state is not None and ctx.compilation_state.mode == 1: - outputs = entity(**kwargs) + if inspect.iscoroutinefunction(entity.__call__): + outputs = run_sync(entity, **kwargs) + else: + outputs = entity(**kwargs) # This is always the output of create_and_link_node which returns create_task_output, which can be # VoidPromise, Promise, or our custom namedtuple of Promises. node = ctx.compilation_state.nodes[-1] diff --git a/flytekit/core/python_function_task.py b/flytekit/core/python_function_task.py index 4544c435b7..c305c0f2a2 100644 --- a/flytekit/core/python_function_task.py +++ b/flytekit/core/python_function_task.py @@ -20,6 +20,8 @@ import inspect import os import signal +import time +import typing from abc import ABC from collections import OrderedDict from contextlib import suppress @@ -32,7 +34,7 @@ from flytekit.core.constants import EAGER_ROOT_ENV_NAME from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager from flytekit.core.docstring import Docstring -from flytekit.core.interface import transform_function_to_interface +from flytekit.core.interface import Interface, transform_function_to_interface from flytekit.core.promise import ( Promise, VoidPromise, @@ -59,11 +61,16 @@ from flytekit.models import dynamic_job as _dynamic_job from flytekit.models import literals as _literal_models from flytekit.models import task as task_models +from flytekit.models.admin import common as admin_common_models from flytekit.models.admin import workflow as admin_workflow_models +from flytekit.models.filters import ValueIn +from flytekit.models.literals import LiteralMap from flytekit.utils.asyn import loop_manager T = TypeVar("T") +CLEANUP_LOOP_DELAY_SECONDS = 1 + class PythonInstanceTask(PythonAutoContainerTask[T], ABC): # type: ignore """ @@ -636,3 +643,118 @@ def run(self, remote: "FlyteRemote", ss: SerializationSettings, **kwargs): # ty with FlyteContextManager.with_context(builder): return loop_manager.run_sync(self.async_execute, self, **kwargs) + + def get_as_workflow(self): + from flytekit.core.workflow import ImperativeWorkflow + + cleanup = EagerFailureHandlerTask(name=f"{self.name}-cleanup", inputs=self.python_interface.inputs) + wb = ImperativeWorkflow(name=self.name) + + input_kwargs = {} + for input_name, input_python_type in self.python_interface.inputs.items(): + wb.add_workflow_input(input_name, input_python_type) + input_kwargs[input_name] = wb.inputs[input_name] + + node = wb.add_entity(self, **input_kwargs) + for output_name, output_python_type in self.python_interface.outputs.items(): + wb.add_workflow_output(output_name, node.outputs[output_name]) + + wb.add_on_failure_handler(cleanup) + return wb + + +class EagerFailureTaskResolver(TaskResolverMixin): + @property + def location(self) -> str: + return f"{EagerFailureTaskResolver.__module__}.eager_failure_task_resolver" + + def name(self) -> str: + return "eager_failure_task_resolver" + + def load_task(self, loader_args: List[str]) -> Task: + """ + Given the set of identifier keys, should return one Python Task or raise an error if not found + """ + return EagerFailureHandlerTask(name="no_input_default_cleanup_task", inputs={}) + + def loader_args(self, settings: SerializationSettings, t: Task) -> List[str]: + """ + Return a list of strings that can help identify the parameter Task + """ + return ["eager", "failure", "handler"] + + def get_all_tasks(self) -> List[Task]: + """ + Future proof method. Just making it easy to access all tasks (Not required today as we auto register them) + """ + return [] + + +eager_failure_task_resolver = EagerFailureTaskResolver() + + +class EagerFailureHandlerTask(PythonAutoContainerTask, metaclass=FlyteTrackedABC): + _TASK_TYPE = "eager_failure_handler_task" + + def __init__(self, name: str, inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, **kwargs): + """ """ + inputs = inputs or {} + super().__init__( + task_type=self._TASK_TYPE, + name=name, + interface=Interface(inputs=inputs, outputs=None), + task_config=None, + task_resolver=eager_failure_task_resolver, + **kwargs, + ) + + def dispatch_execute(self, ctx: FlyteContext, input_literal_map: LiteralMap) -> LiteralMap: + """ + This task should only be called during remote execution. Because when rehydrating this task at execution + time, we don't have access to the python interface of the corresponding eager task/workflow, we don't + have the Python types to convert the input literal map, but nor do we need them. + This task is responsible only for ensuring that all executions are terminated. + """ + # Recursive imports + from flytekit import current_context + from flytekit.configuration.plugin import get_plugin + + most_recent = admin_common_models.Sort("created_at", admin_common_models.Sort.Direction.DESCENDING) + current_exec_id = current_context().execution_id + project = current_exec_id.project + domain = current_exec_id.domain + name = current_exec_id.name + logger.warning(f"Cleaning up potentially still running tasks for execution {name} in {project}/{domain}") + try: + remote = get_plugin().get_remote(config=None, project=project, domain=domain) + except Exception as e: + print(e, flush=True) + import sys + + sys.exit(1) + key_filter = ValueIn("execution_tag.key", ["eager-exec"]) + value_filter = ValueIn("execution_tag.value", [name]) + phase_filter = ValueIn("phase", ["UNDEFINED", "QUEUED", "RUNNING"]) + # This should be made more robust, currently lacking retries and exception handling + while True: + exec_models, _ = remote.client.list_executions_paginated( + project, + domain, + limit=100, + filters=[key_filter, value_filter, phase_filter], + sort_by=most_recent, + ) + logger.warning(f"Found {len(exec_models)} executions this round for termination") + if not exec_models: + break + logger.warning(exec_models) + for exec_model in exec_models: + logger.warning(f"Terminating execution {exec_model.id}, phase {exec_model.closure.phase}") + remote.client.terminate_execution(exec_model.id, f"clean up by parent eager execution {name}") + time.sleep(CLEANUP_LOOP_DELAY_SECONDS) + + # Just echo back + return input_literal_map + + def execute(self, **kwargs) -> Any: + raise AssertionError("this task shouldn't need to call execute") diff --git a/flytekit/core/workflow.py b/flytekit/core/workflow.py index aa213502bb..841a7f60b8 100644 --- a/flytekit/core/workflow.py +++ b/flytekit/core/workflow.py @@ -634,6 +634,40 @@ def add_launch_plan(self, launch_plan: _annotated_launch_plan.LaunchPlan, **kwar def add_subwf(self, sub_wf: WorkflowBase, **kwargs) -> Node: return self.add_entity(sub_wf, **kwargs) + def add_on_failure_handler(self, entity): + """ + This is a special function that mimics the add_entity function, but this is only used + to add the failure node. Failure nodes are special because we don't want + them to be part of the main workflow. + """ + from flytekit.core.node_creation import create_node + + ctx = FlyteContext.current_context() + if ctx.compilation_state is not None: + raise RuntimeError("Can't already be compiling") + with FlyteContextManager.with_context(ctx.with_compilation_state(self.compilation_state)) as ctx: + if entity.python_interface and self.python_interface: + workflow_inputs = self.python_interface.inputs + failure_node_inputs = entity.python_interface.inputs + + # Workflow inputs should be a subset of failure node inputs. + if (failure_node_inputs | workflow_inputs) != failure_node_inputs: + raise FlyteFailureNodeInputMismatchException(self.on_failure, self) + additional_keys = failure_node_inputs.keys() - workflow_inputs.keys() + # Raising an error if the additional inputs in the failure node are not optional. + for k in additional_keys: + if not is_optional_type(failure_node_inputs[k]): + raise FlyteFailureNodeInputMismatchException(self.on_failure, self) + + n = create_node(entity=entity, **self._inputs) + # Maybe this can be cleaned up, but the create node function creates a node + # and add it to the compilation state. We need to pop it off because we don't + # want it in the actual workflow. + ctx.compilation_state.nodes.pop(-1) + self._failure_node = n + n._id = _common_constants.DEFAULT_FAILURE_NODE_ID + return n # type: ignore + def ready(self) -> bool: """ This function returns whether or not the workflow is in a ready state, which means diff --git a/flytekit/models/filters.py b/flytekit/models/filters.py index 5d7bb55104..3f994ea2e7 100644 --- a/flytekit/models/filters.py +++ b/flytekit/models/filters.py @@ -64,6 +64,8 @@ def from_python_std(cls, string): return Contains._parse_from_string(string) elif string.startswith("value_in("): return ValueIn._parse_from_string(string) + elif string.startswith("value_not_in("): + return ValueNotIn._parse_from_string(string) else: raise ValueError("'{}' could not be parsed into a filter.".format(string)) @@ -133,3 +135,7 @@ class Contains(SetFilter): class ValueIn(SetFilter): _comparator = "value_in" + + +class ValueNotIn(SetFilter): + _comparator = "value_not_in" diff --git a/flytekit/tools/serialize_helpers.py b/flytekit/tools/serialize_helpers.py index 8d4cfcb99c..5edb66dbb5 100644 --- a/flytekit/tools/serialize_helpers.py +++ b/flytekit/tools/serialize_helpers.py @@ -9,6 +9,7 @@ from flytekit import LaunchPlan from flytekit.core import context_manager as flyte_context from flytekit.core.base_task import PythonTask +from flytekit.core.python_function_task import EagerAsyncPythonFunctionTask from flytekit.core.workflow import WorkflowBase from flytekit.models import launch_plan as _launch_plan_models from flytekit.models import task as task_models @@ -60,6 +61,11 @@ def get_registrable_entities( lp = LaunchPlan.get_default_launch_plan(ctx, entity) get_serializable(new_api_serializable_entities, ctx.serialization_settings, lp, options) + if isinstance(entity, EagerAsyncPythonFunctionTask): + wf = entity.get_as_workflow() + lp = LaunchPlan.get_default_launch_plan(ctx, wf) + get_serializable(new_api_serializable_entities, ctx.serialization_settings, lp, options) + new_api_model_values = list(new_api_serializable_entities.values()) entities_to_be_serialized = list(filter(_should_register_with_admin, new_api_model_values)) diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py index a295f75078..6ec91be970 100644 --- a/flytekit/tools/translator.py +++ b/flytekit/tools/translator.py @@ -229,6 +229,10 @@ def get_serializable_workflow( if n.id == _common_constants.GLOBAL_INPUT_NODE_ID: continue + # Ensure no node is named the failure node id + if n.id == _common_constants.DEFAULT_FAILURE_NODE_ID: + raise ValueError(f"Node {n.id} is reserved for the failure node") + # Recursively serialize the node serialized_nodes.append(get_serializable(entity_mapping, settings, n, options)) diff --git a/tests/flytekit/unit/core/test_eager_cleanup.py b/tests/flytekit/unit/core/test_eager_cleanup.py new file mode 100644 index 0000000000..914800b4c3 --- /dev/null +++ b/tests/flytekit/unit/core/test_eager_cleanup.py @@ -0,0 +1,33 @@ +from collections import OrderedDict + +import flytekit.configuration +from flytekit.configuration import Image, ImageConfig +from flytekit.core.python_function_task import EagerFailureHandlerTask +from flytekit.tools.translator import get_serializable + +default_img = Image(name="default", fqn="test", tag="tag") +serialization_settings = flytekit.configuration.SerializationSettings( + project="project", + domain="domain", + version="version", + env=None, + image_config=ImageConfig(default_image=default_img, images=[default_img]), +) + + +def test_failure(): + t = EagerFailureHandlerTask(name="tester", inputs={"a": int}) + + spec = get_serializable(OrderedDict(), serialization_settings, t) + print(spec) + + assert spec.template.container.args == ['pyflyte-execute', '--inputs', '{{.input}}', '--output-prefix', '{{.outputPrefix}}', '--raw-output-data-prefix', '{{.rawOutputDataPrefix}}', '--checkpoint-path', '{{.checkpointOutputPrefix}}', '--prev-checkpoint', '{{.prevCheckpointPrefix}}', '--resolver', 'flytekit.core.python_function_task.eager_failure_task_resolver', '--', 'eager', 'failure', 'handler'] + + +def test_loading(): + from flytekit.tools.module_loader import load_object_from_module + + resolver = load_object_from_module("flytekit.core.python_function_task.eager_failure_task_resolver") + print(resolver) + t = resolver.load_task([]) + assert isinstance(t, EagerFailureHandlerTask) diff --git a/tests/flytekit/unit/core/test_imperative.py b/tests/flytekit/unit/core/test_imperative.py index aee88e19d1..a8fe746a6d 100644 --- a/tests/flytekit/unit/core/test_imperative.py +++ b/tests/flytekit/unit/core/test_imperative.py @@ -3,19 +3,21 @@ from collections import OrderedDict import pytest - +from dataclasses import dataclass, fields, field import flytekit.configuration from flytekit.configuration import Image, ImageConfig from flytekit.core.base_task import kwtypes from flytekit.core.launch_plan import LaunchPlan from flytekit.core.task import reference_task, task from flytekit.core.workflow import ImperativeWorkflow, get_promise, workflow +from flytekit.core.python_function_task import EagerFailureHandlerTask from flytekit.exceptions.user import FlyteValidationException from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task from flytekit.models import literals as literal_models from flytekit.tools.translator import get_serializable from flytekit.types.file import FlyteFile from flytekit.types.schema import FlyteSchema +from flytekit.models.admin.workflow import WorkflowSpec default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = flytekit.configuration.SerializationSettings( @@ -137,6 +139,59 @@ def t1(a: typing.Dict[str, typing.List[int]]) -> typing.Dict[str, int]: assert wb(in1=3, in2=4, in3=5) == {"a": 7, "b": 9} +def test_imperative_with_failure(): + + @dataclass + class DC: + string: typing.Optional[str] = None + + @task + def t1(a: typing.Dict[str, typing.List[int]]) -> typing.Dict[str, int]: + return {k: sum(v) for k, v in a.items()} + + @task + def t2(): + print("side effect") + + @task + def t3(dc: DC) -> DC: + if dc.string is None: + return DC(string="default") + return DC(string=dc.string + " world") # type: ignore[operator] + + wb = ImperativeWorkflow(name="my.workflow.a") + + # mapped inputs + in1 = wb.add_workflow_input("in1", int) + wb.add_workflow_input("in2", int) + in3 = wb.add_workflow_input("in3", int) + node = wb.add_entity(t1, a={"a": [in1, wb.inputs["in2"]], "b": [wb.inputs["in2"], in3]}) + wb.add_workflow_output("from_n0t1", node.outputs["o0"]) + + # pure side effect task + wb.add_entity(t2) + + failure_task = EagerFailureHandlerTask(name="sample-failure-task", inputs=wb.python_interface.inputs) + wb.add_on_failure_handler(failure_task) + + # Add a data + dc_input = wb.add_workflow_input("dc_in", DC) + node_dc = wb.add_entity(t3, dc=dc_input) + wb.add_workflow_output("updated_dc", node_dc.outputs["o0"]) + + r = wb(in1=3, in2=4, in3=5, dc_in=DC(string="hello")) + assert r.from_n0t1 == {"a": 7, "b": 9} + assert r.updated_dc.string == "hello world" + + wf_spec: WorkflowSpec = get_serializable(OrderedDict(), serialization_settings, wb) + assert len(wf_spec.template.nodes) == 3 + assert len(wf_spec.template.interface.inputs) == 4 + + node_names = [n.id for n in wf_spec.template.nodes] + assert wf_spec.template.failure_node is not None + assert wf_spec.template.failure_node.id == "efn" + + def test_imperative_with_list_io(): @task def t1(a: int) -> typing.List[int]: diff --git a/tests/flytekit/unit/core/test_serialization.py b/tests/flytekit/unit/core/test_serialization.py index 82b69859bd..f179b1eecc 100644 --- a/tests/flytekit/unit/core/test_serialization.py +++ b/tests/flytekit/unit/core/test_serialization.py @@ -549,6 +549,19 @@ def wf_with_sub_wf() -> typing.Tuple[int, int]: assert wf_with_sub_wf() == (default_val, input_val) +def test_failure_node_naming(): + @task + def t1(a: int) -> int: + return a + + @workflow + def wf(a: int) -> int: + return t1(a=a).with_overrides(node_name="efn") + + with pytest.raises(ValueError): + get_serializable(OrderedDict(), serialization_settings, wf) + + def test_default_args_task_str_type(): default_val = "" input_val = "foo"