diff --git a/flytekit/remote/entities.py b/flytekit/remote/entities.py index fd78d4c3c4..7c6f72ec1b 100644 --- a/flytekit/remote/entities.py +++ b/flytekit/remote/entities.py @@ -348,10 +348,28 @@ def promote_from_model(cls, model: _workflow_model.GateNode): class FlyteArrayNode(_workflow_model.ArrayNode): + def __init__( + self, + flyte_node: FlyteNode, + parallelism: int, + min_successes: int, + min_success_ratio: float, + ): + super().__init__(flyte_node, parallelism, min_successes, min_success_ratio) + self._flyte_node = flyte_node + + @property + def flyte_node(self) -> FlyteNode: + return self._flyte_node + @classmethod - def promote_from_model(cls, model: _workflow_model.ArrayNode): + def promote_from_model( + cls, + model: _workflow_model.ArrayNode, + flyte_node: FlyteNode, + ): return cls( - node=model._node, + flyte_node=flyte_node, parallelism=model._parallelism, min_successes=model._min_successes, min_success_ratio=model._min_success_ratio, @@ -406,7 +424,7 @@ def task_node(self) -> Optional[FlyteTaskNode]: return self._flyte_task_node @property - def flyte_entity(self) -> Union[FlyteTask, FlyteWorkflow, FlyteLaunchPlan, FlyteBranchNode]: + def flyte_entity(self) -> Union[FlyteTask, FlyteWorkflow, FlyteLaunchPlan, FlyteBranchNode, FlyteArrayNode]: return self._flyte_entity @classmethod @@ -477,8 +495,21 @@ def promote_from_model( elif model.gate_node is not None: flyte_gate_node = FlyteGateNode.promote_from_model(model.gate_node) elif model.array_node is not None: - flyte_array_node = FlyteArrayNode.promote_from_model(model.array_node) - # TODO: validate task in tasks + if model.array_node.node is None: + raise _system_exceptions.FlyteSystemException( + f"Bad Node model, array node detected but no node specified, node: {model}" + ) + flyte_node, converted_sub_workflows = cls.promote_from_model( + model.array_node.node, + sub_workflows, + node_launch_plans, + tasks, + converted_sub_workflows, + ) + flyte_array_node = FlyteArrayNode.promote_from_model( + model.array_node, + flyte_node, + ) else: raise _system_exceptions.FlyteSystemException( f"Bad Node model, neither task nor workflow detected, node: {model}" diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py index ef8b28d866..804326f79e 100644 --- a/flytekit/remote/remote.py +++ b/flytekit/remote/remote.py @@ -2615,10 +2615,9 @@ def sync_node_execution( if execution._node.array_node is None: logger.error("Array node not found") return execution - # if there's a task node underneath the array node, let's fetch the interface for it + # if there's a task node underneath the array node if execution._node.array_node.node.task_node is not None: - tid = execution._node.array_node.node.task_node.reference_id - t = self.fetch_task(tid.project, tid.domain, tid.name, tid.version) + t = execution._node.flyte_entity.flyte_node.task_node.flyte_task execution._task_executions = [ self.sync_task_execution(FlyteTaskExecution.promote_from_model(task_execution), t) for task_execution in iterate_task_executions(self.client, execution.id) diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 6b3d8ee855..51e79ff0e9 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -665,6 +665,23 @@ def test_execute_workflow_with_maptask(register): assert execution.outputs["o0"] == [4, 5, 6] assert len(execution.node_executions["n0"].task_executions) == 1 +def test_execution_workflow_with_maptask_in_dynamic(register): + remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) + d: typing.List[int] = [1, 2, 3] + flyte_launch_plan = remote.fetch_launch_plan(name="basic.dynamic_array_map.workflow_with_maptask_in_dynamic", version=VERSION) + execution = remote.execute( + flyte_launch_plan, + inputs={"data": d}, + version=VERSION, + wait=True, + ) + assert execution.outputs["o0"] == [2, 3, 4] + assert "n0" in execution.node_executions + assert execution.node_executions["n0"].subworkflow_node_executions is not None + assert "n0-0-dn0" in execution.node_executions["n0"].subworkflow_node_executions + assert len(execution.node_executions["n0"].subworkflow_node_executions["n0-0-dn0"].task_executions) == 1 + + def test_executes_nested_workflow_dictating_interruptible(register): remote = FlyteRemote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN) flyte_launch_plan = remote.fetch_launch_plan(name="basic.child_workflow.parent_wf", version=VERSION) diff --git a/tests/flytekit/integration/remote/workflows/basic/dynamic_array_map.py b/tests/flytekit/integration/remote/workflows/basic/dynamic_array_map.py new file mode 100644 index 0000000000..12407b095a --- /dev/null +++ b/tests/flytekit/integration/remote/workflows/basic/dynamic_array_map.py @@ -0,0 +1,16 @@ +from flytekit import workflow, task, map_task, dynamic + + +@task +def fn(x: int) -> int: + return x + 1 + + +@dynamic +def dynamic(data: list[int]) -> list[int]: + return map_task(fn)(x=data) + + +@workflow +def workflow_with_maptask_in_dynamic(data: list[int]) -> list[int]: + return dynamic(data=data)