Skip to content

Commit

Permalink
fix sync for map task in dynamic (#3141)
Browse files Browse the repository at this point in the history
fix sync for map task in dynamic

Signed-off-by: Troy Chiu <[email protected]>
  • Loading branch information
troychiu authored Feb 21, 2025
1 parent 74cdcfb commit d7a3cef
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 8 deletions.
41 changes: 36 additions & 5 deletions flytekit/remote/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,10 +348,28 @@ def promote_from_model(cls, model: _workflow_model.GateNode):


class FlyteArrayNode(_workflow_model.ArrayNode):
def __init__(
self,
flyte_node: FlyteNode,
parallelism: int,
min_successes: int,
min_success_ratio: float,
):
super().__init__(flyte_node, parallelism, min_successes, min_success_ratio)
self._flyte_node = flyte_node

@property
def flyte_node(self) -> FlyteNode:
return self._flyte_node

@classmethod
def promote_from_model(cls, model: _workflow_model.ArrayNode):
def promote_from_model(
cls,
model: _workflow_model.ArrayNode,
flyte_node: FlyteNode,
):
return cls(
node=model._node,
flyte_node=flyte_node,
parallelism=model._parallelism,
min_successes=model._min_successes,
min_success_ratio=model._min_success_ratio,
Expand Down Expand Up @@ -406,7 +424,7 @@ def task_node(self) -> Optional[FlyteTaskNode]:
return self._flyte_task_node

@property
def flyte_entity(self) -> Union[FlyteTask, FlyteWorkflow, FlyteLaunchPlan, FlyteBranchNode]:
def flyte_entity(self) -> Union[FlyteTask, FlyteWorkflow, FlyteLaunchPlan, FlyteBranchNode, FlyteArrayNode]:
return self._flyte_entity

@classmethod
Expand Down Expand Up @@ -477,8 +495,21 @@ def promote_from_model(
elif model.gate_node is not None:
flyte_gate_node = FlyteGateNode.promote_from_model(model.gate_node)
elif model.array_node is not None:
flyte_array_node = FlyteArrayNode.promote_from_model(model.array_node)
# TODO: validate task in tasks
if model.array_node.node is None:
raise _system_exceptions.FlyteSystemException(
f"Bad Node model, array node detected but no node specified, node: {model}"
)
flyte_node, converted_sub_workflows = cls.promote_from_model(
model.array_node.node,
sub_workflows,
node_launch_plans,
tasks,
converted_sub_workflows,
)
flyte_array_node = FlyteArrayNode.promote_from_model(
model.array_node,
flyte_node,
)
else:
raise _system_exceptions.FlyteSystemException(
f"Bad Node model, neither task nor workflow detected, node: {model}"
Expand Down
5 changes: 2 additions & 3 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -2615,10 +2615,9 @@ def sync_node_execution(
if execution._node.array_node is None:
logger.error("Array node not found")
return execution
# if there's a task node underneath the array node, let's fetch the interface for it
# if there's a task node underneath the array node
if execution._node.array_node.node.task_node is not None:
tid = execution._node.array_node.node.task_node.reference_id
t = self.fetch_task(tid.project, tid.domain, tid.name, tid.version)
t = execution._node.flyte_entity.flyte_node.task_node.flyte_task
execution._task_executions = [
self.sync_task_execution(FlyteTaskExecution.promote_from_model(task_execution), t)
for task_execution in iterate_task_executions(self.client, execution.id)
Expand Down
17 changes: 17 additions & 0 deletions tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,6 +665,23 @@ def test_execute_workflow_with_maptask(register):
assert execution.outputs["o0"] == [4, 5, 6]
assert len(execution.node_executions["n0"].task_executions) == 1

def test_execution_workflow_with_maptask_in_dynamic(register):
remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
d: typing.List[int] = [1, 2, 3]
flyte_launch_plan = remote.fetch_launch_plan(name="basic.dynamic_array_map.workflow_with_maptask_in_dynamic", version=VERSION)
execution = remote.execute(
flyte_launch_plan,
inputs={"data": d},
version=VERSION,
wait=True,
)
assert execution.outputs["o0"] == [2, 3, 4]
assert "n0" in execution.node_executions
assert execution.node_executions["n0"].subworkflow_node_executions is not None
assert "n0-0-dn0" in execution.node_executions["n0"].subworkflow_node_executions
assert len(execution.node_executions["n0"].subworkflow_node_executions["n0-0-dn0"].task_executions) == 1


def test_executes_nested_workflow_dictating_interruptible(register):
remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN)
flyte_launch_plan = remote.fetch_launch_plan(name="basic.child_workflow.parent_wf", version=VERSION)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from flytekit import workflow, task, map_task, dynamic


@task
def fn(x: int) -> int:
return x + 1


@dynamic
def dynamic(data: list[int]) -> list[int]:
return map_task(fn)(x=data)


@workflow
def workflow_with_maptask_in_dynamic(data: list[int]) -> list[int]:
return dynamic(data=data)

0 comments on commit d7a3cef

Please sign in to comment.