Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Eager cleanup #3148

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions flytekit/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
START_NODE_ID = "start-node"
END_NODE_ID = "end-node"

DEFAULT_FAILURE_NODE_ID = "efn"

# If set, this environment variable overrides the default container image and the default base image in ImageSpec.
FLYTE_INTERNAL_IMAGE_ENV_VAR = "FLYTE_INTERNAL_IMAGE"

Expand Down
7 changes: 6 additions & 1 deletion flytekit/core/node_creation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import inspect
from typing import TYPE_CHECKING, Union

from flytekit.core.base_task import PythonTask
Expand All @@ -10,6 +11,7 @@
from flytekit.core.workflow import WorkflowBase
from flytekit.exceptions import user as _user_exceptions
from flytekit.loggers import logger
from flytekit.utils.asyn import run_sync

if TYPE_CHECKING:
from flytekit.remote.remote_callable import RemoteEntity
Expand Down Expand Up @@ -77,7 +79,10 @@ def create_node(
# When compiling, calling the entity will create a node.
ctx = FlyteContext.current_context()
if ctx.compilation_state is not None and ctx.compilation_state.mode == 1:
outputs = entity(**kwargs)
if inspect.iscoroutinefunction(entity.__call__):
outputs = run_sync(entity, **kwargs)
else:
outputs = entity(**kwargs)
Comment on lines +82 to +85
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding error handling for async execution

Consider handling potential exceptions from run_sync when executing coroutine functions. The current implementation may silently fail if the async execution encounters issues.

Code suggestion
Check the AI-generated fix before applying
Suggested change
if inspect.iscoroutinefunction(entity.__call__):
outputs = run_sync(entity, **kwargs)
else:
outputs = entity(**kwargs)
if inspect.iscoroutinefunction(entity.__call__):
try:
outputs = run_sync(entity, **kwargs)
except Exception as e:
raise RuntimeError(f"Async execution failed for {entity.name}: {str(e)}") from e
else:
outputs = entity(**kwargs)

Code Review Run #ce446d


Should Bito avoid suggestions like this for future reviews? (Manage Rules)

  • Yes, avoid them

# This is always the output of create_and_link_node which returns create_task_output, which can be
# VoidPromise, Promise, or our custom namedtuple of Promises.
node = ctx.compilation_state.nodes[-1]
Expand Down
139 changes: 138 additions & 1 deletion flytekit/core/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import inspect
import os
import signal
import time
import typing
from abc import ABC
from collections import OrderedDict
from contextlib import suppress
Expand All @@ -32,7 +34,7 @@
from flytekit.core.constants import EAGER_ROOT_ENV_NAME
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
from flytekit.core.docstring import Docstring
from flytekit.core.interface import transform_function_to_interface
from flytekit.core.interface import Interface, transform_function_to_interface
from flytekit.core.promise import (
Promise,
VoidPromise,
Expand All @@ -59,11 +61,17 @@
from flytekit.models import dynamic_job as _dynamic_job
from flytekit.models import literals as _literal_models
from flytekit.models import task as task_models
from flytekit.models.admin import common as admin_common_models
from flytekit.models.admin import workflow as admin_workflow_models
from flytekit.models.filters import ValueIn
from flytekit.models.literals import LiteralMap
from flytekit.models.security import Secret
from flytekit.utils.asyn import loop_manager

T = TypeVar("T")

CLEANUP_LOOP_DELAY_SECONDS = 1


class PythonInstanceTask(PythonAutoContainerTask[T], ABC): # type: ignore
"""
Expand Down Expand Up @@ -549,6 +557,13 @@ def execute(self, **kwargs) -> Any:
)
raw_output = ctx.user_space_params.raw_output_prefix if ctx.user_space_params else None
logger.info(f"Constructing default remote with no config and {project}, {domain}, {raw_output}")
import os
from union._config import _get_union_api_env_var
api_value_tuple = _get_union_api_env_var()
print(f"111!!!!!!!!!!!!!!!!??????????>>>>>>>>>>>>!!!!!!!!!!! {os.environ['_UNION_EAGER_API_KEY']}",
flush=True)
print(f"111!!!!!!!!!!!!!!!!??????????>>>>!!!!! {api_value_tuple}", flush=True)

remote = get_plugin().get_remote(
config=None, project=project, domain=domain, data_upload_location=raw_output
)
Expand Down Expand Up @@ -636,3 +651,125 @@ def run(self, remote: "FlyteRemote", ss: SerializationSettings, **kwargs): # ty

with FlyteContextManager.with_context(builder):
return loop_manager.run_sync(self.async_execute, self, **kwargs)

def get_as_workflow(self):
from flytekit.core.workflow import ImperativeWorkflow

cleanup = EagerFailureHandlerTask(name=f"{self.name}-cleanup", inputs=self.python_interface.inputs)
# todo: remove this before merging
# this is actually bad, but useful for developing
cleanup._container_image = self._container_image
Comment on lines +659 to +661
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

wb = ImperativeWorkflow(name=self.name)

input_kwargs = {}
for input_name, input_python_type in self.python_interface.inputs.items():
wb.add_workflow_input(input_name, input_python_type)
input_kwargs[input_name] = wb.inputs[input_name]

node = wb.add_entity(self, **input_kwargs)
for output_name, output_python_type in self.python_interface.outputs.items():
wb.add_workflow_output(output_name, node.outputs[output_name])

wb.add_on_failure_handler(cleanup)
return wb


class EagerFailureTaskResolver(TaskResolverMixin):
    """
    Task resolver for the eager failure (cleanup) handler.

    Nothing is looked up here: regardless of the loader arguments it is given,
    the resolver builds a new ``EagerFailureHandlerTask`` that declares no inputs.
    """

    @property
    def location(self) -> str:
        """Dotted import path of the module-level ``eager_failure_task_resolver`` instance."""
        return f"{EagerFailureTaskResolver.__module__}.eager_failure_task_resolver"

    def name(self) -> str:
        """Return the fixed name of this resolver."""
        return "eager_failure_task_resolver"

    def load_task(self, loader_args: List[str]) -> Task:
        """
        Return a freshly constructed ``EagerFailureHandlerTask`` with an empty input set.

        ``loader_args`` is accepted for interface compatibility but is not read.
        """
        return EagerFailureHandlerTask(name="no_input_default_cleanup_task", inputs={})

    def loader_args(self, settings: SerializationSettings, t: Task) -> List[str]:
        """
        Return the constant loader arguments for this resolver.

        Neither ``settings`` nor ``t`` influences the result.
        """
        return ["eager", "failure", "handler"]

    def get_all_tasks(self) -> List[Task]:
        """
        Return the tasks known to this resolver, which is always an empty list
        (tasks are auto-registered, so this is not required today).
        """
        return []


# Module-level resolver instance. ``EagerFailureTaskResolver.location`` returns the dotted
# path to this exact name, so it has to stay importable from this module.
eager_failure_task_resolver = EagerFailureTaskResolver()


class EagerFailureHandlerTask(PythonAutoContainerTask, metaclass=FlyteTrackedABC):
_TASK_TYPE = "eager_failure_handler_task"

def __init__(self, name: str, inputs: typing.Optional[typing.Dict[str, typing.Type]] = None, **kwargs):
    """
    Create the failure handler task.

    :param name: Name given to the task.
    :param inputs: Mapping of input name to Python type used to build the task's
        interface. ``None`` (or an empty mapping) results in a task without inputs.
    :param kwargs: Passed through unchanged to the parent constructor.
    """
    super().__init__(
        task_type=self._TASK_TYPE,
        name=name,
        # The task declares inputs only; it has no declared outputs.
        interface=Interface(inputs=inputs or {}, outputs=None),
        task_config=None,
        task_resolver=eager_failure_task_resolver,
        **kwargs,
    )

def dispatch_execute(self, ctx: FlyteContext, input_literal_map: LiteralMap) -> LiteralMap:
"""
This task should only be called during remote execution. Because when rehydrating this task at execution
time, we don't have access to the python interface of the corresponding eager task/workflow, we don't
have the Python types to convert the input literal map, but nor do we need them.
This task is responsible only for ensuring that all executions are terminated.
"""
# Recursive imports
from flytekit import current_context
from flytekit.configuration.plugin import get_plugin

most_recent = admin_common_models.Sort("created_at", admin_common_models.Sort.Direction.DESCENDING)
current_exec_id = current_context().execution_id
project = current_exec_id.project
domain = current_exec_id.domain
name = current_exec_id.name
logger.warning(f"Cleaning up potentially still running tasks for execution {name} in {project}/{domain}")
import os
from union._config import _get_union_api_env_var
api_value_tuple = _get_union_api_env_var()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Invalid syntax with walrus operator

There's a syntax error on line 734 where the walrus operator := is used incorrectly. The walrus operator should be used in an expression context, not as a statement. Consider using a regular assignment statement instead.

Code suggestion
Check the AI-generated fix before applying
Suggested change
api_value_tuple = _get_union_api_env_var()
api_value_tuple = _get_union_api_env_var()

Code Review Run #5f941b


Should Bito avoid suggestions like this for future reviews? (Manage Rules)

  • Yes, avoid them

print(f"!!!!!!!!!!!!!!!!??????????>>>>>>>>>>>>!!!!!!!!!!! {os.environ['_UNION_EAGER_API_KEY']}", flush=True)
print(f"!!!!!!!!!!!!!!!!??????????>>>>!!!!! {api_value_tuple}", flush=True)
try:
remote = get_plugin().get_remote(config=None, project=project, domain=domain)
except Exception as e:
print(e, flush=True)
import sys
sys.exit(1)
Comment on lines +746 to +749
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Blind exception catch needs specificity

Catching a blind Exception at line 746 is generally discouraged as it can mask unexpected errors. Consider catching specific exceptions or at least logging the exception details.

Code suggestion
Check the AI-generated fix before applying
Suggested change
except Exception as e:
print(e, flush=True)
import sys
sys.exit(1)
except (ConnectionError, ValueError, TypeError) as e:
logger.error(f"Failed to get remote: {e}")
import sys
logger.error("Exiting due to failure in getting remote connection")
sys.exit(1)

Code Review Run #685a40


Should Bito avoid suggestions like this for future reviews? (Manage Rules)

  • Yes, avoid them

key_filter = ValueIn("execution_tag.key", ["eager-exec"])
value_filter = ValueIn("execution_tag.value", [name])
phase_filter = ValueIn("phase", ["UNDEFINED", "QUEUED", "RUNNING"])
# This should be made more robust, currently lacking retries and exception handling
while True:
exec_models, _ = remote.client.list_executions_paginated(
project,
domain,
limit=100,
filters=[key_filter, value_filter, phase_filter],
sort_by=most_recent,
)
logger.warning(f"Found {len(exec_models)} executions this round for termination")
if not exec_models:
break
logger.warning(exec_models)
for exec_model in exec_models:
logger.warning(f"Terminating execution {exec_model.id}, phase {exec_model.closure.phase}")
remote.client.terminate_execution(exec_model.id, f"clean up by parent eager execution {name}")
time.sleep(CLEANUP_LOOP_DELAY_SECONDS)

# Just echo back
return input_literal_map

def execute(self, **kwargs) -> Any:
    """
    Always raise ``AssertionError``.

    The task's work is done in ``dispatch_execute``, so this method is not expected to be called.
    """
    raise AssertionError("this task shouldn't need to call execute")
34 changes: 34 additions & 0 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,40 @@ def add_launch_plan(self, launch_plan: _annotated_launch_plan.LaunchPlan, **kwar
def add_subwf(self, sub_wf: WorkflowBase, **kwargs) -> Node:
    """Add ``sub_wf`` to this workflow by delegating to ``add_entity`` and return the node it produces."""
    return self.add_entity(sub_wf, **kwargs)

def add_on_failure_handler(self, entity):
    """
    Attach ``entity`` as this workflow's failure node.

    This is a special function that mimics ``add_entity``, but it is only used to add
    the failure node. Failure nodes are special because we don't want them to be part
    of the main workflow: the node is created, removed from the compilation state
    again, stored on the workflow as its failure node, and given the reserved id
    ``DEFAULT_FAILURE_NODE_ID``.

    :param entity: The entity to use as the failure handler. When both the entity and
        this workflow have a Python interface, every workflow input must also be an
        input of the entity (same name and type), and any additional entity input
        must be optional.
    :return: The node created for ``entity``.
    :raises RuntimeError: If the current context already has a compilation state.
    :raises FlyteFailureNodeInputMismatchException: If the input check above fails.
    """
    from flytekit.core.node_creation import create_node

    ctx = FlyteContext.current_context()
    if ctx.compilation_state is not None:
        raise RuntimeError("Can't already be compiling")
    with FlyteContextManager.with_context(ctx.with_compilation_state(self.compilation_state)) as ctx:
        if entity.python_interface and self.python_interface:
            workflow_inputs = self.python_interface.inputs
            failure_node_inputs = entity.python_interface.inputs

            # Workflow inputs should be a subset of failure node inputs.
            # The error is raised with ``entity``, the handler being validated here.
            # ``self.on_failure`` is not set by this method and may not refer to it.
            if (failure_node_inputs | workflow_inputs) != failure_node_inputs:
                raise FlyteFailureNodeInputMismatchException(entity, self)
            additional_keys = failure_node_inputs.keys() - workflow_inputs.keys()
            # Raise an error if the additional inputs in the failure node are not optional.
            for k in additional_keys:
                if not is_optional_type(failure_node_inputs[k]):
                    raise FlyteFailureNodeInputMismatchException(entity, self)

        n = create_node(entity=entity, **self._inputs)
        # Maybe this can be cleaned up, but the create_node function creates a node
        # and adds it to the compilation state. We need to pop it off because we don't
        # want it in the actual workflow.
        ctx.compilation_state.nodes.pop(-1)
        self._failure_node = n
        n._id = _common_constants.DEFAULT_FAILURE_NODE_ID
        return n  # type: ignore

def ready(self) -> bool:
"""
This function returns whether or not the workflow is in a ready state, which means
Expand Down
6 changes: 6 additions & 0 deletions flytekit/models/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def from_python_std(cls, string):
return Contains._parse_from_string(string)
elif string.startswith("value_in("):
return ValueIn._parse_from_string(string)
elif string.startswith("value_not_in("):
return ValueNotIn._parse_from_string(string)
else:
raise ValueError("'{}' could not be parsed into a filter.".format(string))

Expand Down Expand Up @@ -133,3 +135,7 @@ class Contains(SetFilter):

class ValueIn(SetFilter):
_comparator = "value_in"


class ValueNotIn(SetFilter):
    """Set filter that uses the ``value_not_in`` comparator."""

    _comparator = "value_not_in"
6 changes: 6 additions & 0 deletions flytekit/tools/serialize_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from flytekit import LaunchPlan
from flytekit.core import context_manager as flyte_context
from flytekit.core.base_task import PythonTask
from flytekit.core.python_function_task import EagerAsyncPythonFunctionTask
from flytekit.core.workflow import WorkflowBase
from flytekit.models import launch_plan as _launch_plan_models
from flytekit.models import task as task_models
Expand Down Expand Up @@ -60,6 +61,11 @@ def get_registrable_entities(
lp = LaunchPlan.get_default_launch_plan(ctx, entity)
get_serializable(new_api_serializable_entities, ctx.serialization_settings, lp, options)

if isinstance(entity, EagerAsyncPythonFunctionTask):
wf = entity.get_as_workflow()
lp = LaunchPlan.get_default_launch_plan(ctx, wf)
get_serializable(new_api_serializable_entities, ctx.serialization_settings, lp, options)

new_api_model_values = list(new_api_serializable_entities.values())
entities_to_be_serialized = list(filter(_should_register_with_admin, new_api_model_values))

Expand Down
4 changes: 4 additions & 0 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,10 @@ def get_serializable_workflow(
if n.id == _common_constants.GLOBAL_INPUT_NODE_ID:
continue

# Ensure no node is named the failure node id
if n.id == _common_constants.DEFAULT_FAILURE_NODE_ID:
raise ValueError(f"Node {n.id} is reserved for the failure node")

# Recursively serialize the node
serialized_nodes.append(get_serializable(entity_mapping, settings, n, options))

Expand Down
33 changes: 33 additions & 0 deletions tests/flytekit/unit/core/test_eager_cleanup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from collections import OrderedDict

import flytekit.configuration
from flytekit.configuration import Image, ImageConfig
from flytekit.core.python_function_task import EagerFailureHandlerTask
from flytekit.tools.translator import get_serializable

# Module-level fixtures: one image used both as the default and as the only entry of the
# image config, plus the serialization settings that test_failure passes to get_serializable.
default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = flytekit.configuration.SerializationSettings(
    project="project",
    domain="domain",
    version="version",
    env=None,
    image_config=ImageConfig(default_image=default_img, images=[default_img]),
)


def test_failure():
    """Serializing the failure handler task yields container args that point at the eager failure resolver."""
    handler_task = EagerFailureHandlerTask(name="tester", inputs={"a": int})

    spec = get_serializable(OrderedDict(), serialization_settings, handler_task)
    print(spec)

    expected_args = [
        'pyflyte-execute',
        '--inputs',
        '{{.input}}',
        '--output-prefix',
        '{{.outputPrefix}}',
        '--raw-output-data-prefix',
        '{{.rawOutputDataPrefix}}',
        '--checkpoint-path',
        '{{.checkpointOutputPrefix}}',
        '--prev-checkpoint',
        '{{.prevCheckpointPrefix}}',
        '--resolver',
        'flytekit.core.python_function_task.eager_failure_task_resolver',
        '--',
        'eager',
        'failure',
        'handler',
    ]
    assert spec.template.container.args == expected_args


def test_loading():
    """The resolver can be imported by its dotted path and hands back a failure handler task."""
    from flytekit.tools.module_loader import load_object_from_module

    loaded_resolver = load_object_from_module("flytekit.core.python_function_task.eager_failure_task_resolver")
    print(loaded_resolver)
    loaded_task = loaded_resolver.load_task([])
    assert isinstance(loaded_task, EagerFailureHandlerTask)
Loading
Loading