Flytekit: Rename map_task to map, replace min_successes and min_success_ratio with tolerance, rename max_parallelism to concurrency #3107

Open — wants to merge 9 commits into base: master
11 changes: 9 additions & 2 deletions flytekit/__init__.py
@@ -208,6 +208,7 @@

import os
import sys
import warnings
from typing import Generator

from rich import traceback
@@ -219,10 +220,9 @@
else:
from importlib.metadata import entry_points


from flytekit._version import __version__
from flytekit.configuration import Config
from flytekit.core.array_node_map_task import map_task
from flytekit.core.array_node_map_task import map
Reviewer comment (Contributor): Consider maintaining backward compatibility for imports

Consider keeping both map_task and map imports to maintain backward compatibility. The alias is defined later, but importing directly as map may break existing code that uses map_task.

Suggested change (check the AI-generated fix before applying):
-from flytekit.core.array_node_map_task import map
+from flytekit.core.array_node_map_task import map_task

Code Review Run #d47fe6 — resolution: incorrectly flagged.

from flytekit.core.artifact import Artifact
from flytekit.core.base_sql_task import SQLTask
from flytekit.core.base_task import SecurityContext, TaskMetadata, kwtypes
Reviewer comment (Contributor): Consider deprecation notice for map_task rename

Consider keeping the original map_task import and marking it as deprecated using a @deprecated decorator if this is an API change, to maintain backward compatibility. The alias on line 277 may not be sufficient for all use cases.

Suggested change (check the AI-generated fix before applying):
 from flytekit.core.base_task import SecurityContext, TaskMetadata, kwtypes
+from flytekit.core.array_node_map_task import map_task, map
+from deprecated import deprecated

Code Review Run #cbd7b1 — resolution: incorrectly flagged.

@@ -269,6 +269,13 @@
StructuredDatasetType,
)

warnings.warn(
"'map_task' is deprecated and will be removed in a future version. Use 'map' instead.",
DeprecationWarning,
stacklevel=2,
)
map_task = map


def current_context() -> ExecutionParameters:
"""
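Note that the module-level warning in the diff above fires once, unconditionally, at import time, so every importer of flytekit sees it whether or not they ever call map_task. A call-time shim is an alternative pattern that warns only actual users — this is an illustrative sketch, not part of this PR, and the helper name deprecated_alias is invented for the example:

```python
import functools
import warnings


def deprecated_alias(new_func, old_name: str, new_name: str):
    """Wrap new_func so the DeprecationWarning fires only when the old name is called."""

    @functools.wraps(new_func)
    def wrapper(*args, **kwargs):
        warnings.warn(
            f"'{old_name}' is deprecated and will be removed in a future "
            f"version. Use '{new_name}' instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return new_func(*args, **kwargs)

    return wrapper


# hypothetical use in flytekit/__init__.py:
# map_task = deprecated_alias(map, "map_task", "map")
```

With this shape, `import flytekit` stays silent and only code that still calls the old name triggers the warning.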
16 changes: 13 additions & 3 deletions flytekit/clis/sdk_in_container/run.py
@@ -255,16 +255,26 @@ class RunLevelParams(PyFlyteParams):
),
)
)
max_parallelism: int = make_click_option_field(

concurrency: int = make_click_option_field(
click.Option(
param_decls=["--max-parallelism"],
param_decls=["--concurrency"],
required=False,
type=int,
show_default=True,
help="Number of nodes of a workflow that can be executed in parallel. If not specified,"
" project/domain defaults are used. If 0 then it is unlimited.",
)
)
max_parallelism: int = make_click_option_field(
click.Option(
param_decls=["--max-parallelism"],
required=False,
type=int,
show_default=True,
help="[Deprecated] Use --concurrency instead",
)
)
Reviewer comment on lines +269 to +277 (Contributor): Consider removing deprecated max-parallelism option

Consider removing the deprecated --max-parallelism option since --concurrency is now the preferred way to control parallel execution. Having both options may cause confusion for users.

Suggested change (check the AI-generated fix before applying) — delete the deprecated block:
-max_parallelism: int = make_click_option_field(
-    click.Option(
-        param_decls=["--max-parallelism"],
-        required=False,
-        type=int,
-        show_default=True,
-        help="[Deprecated] Use --concurrency instead",
-    )
-)

Code Review Run #cbd7b1 — resolution: incorrectly flagged.

disable_notifications: bool = make_click_option_field(
click.Option(
param_decls=["--disable-notifications"],
@@ -516,7 +526,7 @@ def options_from_run_params(run_level_params: RunLevelParams) -> Options:
raw_output_data_config=RawOutputDataConfig(output_location_prefix=run_level_params.raw_output_data_prefix)
if run_level_params.raw_output_data_prefix
else None,
max_parallelism=run_level_params.max_parallelism,
concurrency=run_level_params.max_parallelism,
disable_notifications=run_level_params.disable_notifications,
security_context=security.SecurityContext(
run_as=security.Identity(k8s_service_account=run_level_params.service_account)
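The run-level parameters now carry both flags, and the diff above forwards run_level_params.max_parallelism straight into concurrency=. One way to reconcile the two at execution-build time is a small precedence helper — a hypothetical sketch, not the PR's code; the name resolve_concurrency is invented here:

```python
import warnings
from typing import Optional


def resolve_concurrency(
    concurrency: Optional[int], max_parallelism: Optional[int]
) -> Optional[int]:
    """Prefer the new flag; warn (and fall back) if only the deprecated one was given."""
    if max_parallelism is not None:
        warnings.warn(
            "--max-parallelism is deprecated; use --concurrency instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        if concurrency is None:
            return max_parallelism
    return concurrency
```

When both flags are supplied, the explicit --concurrency value wins; when neither is supplied, None propagates so project/domain defaults still apply.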
48 changes: 39 additions & 9 deletions flytekit/core/array_node_map_task.py
@@ -3,6 +3,7 @@
import hashlib
import math
import os # TODO: use flytekit logger
import warnings
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast

@@ -369,11 +370,12 @@ def _raw_execute(self, **kwargs) -> Any:
return outputs


def map_task(
def map(
Reviewer comment (Contributor): Consider keeping descriptive function name

Consider keeping the original function name map_task instead of renaming to map, as it could conflict with Python's built-in map function and cause confusion. The original name was more descriptive of the function's purpose.

Suggested change (check the AI-generated fix before applying):
-def map(
+def map_task(

Code Review Run #d47fe6 — resolution: incorrectly flagged.

target: Union[LaunchPlan, PythonFunctionTask, "FlyteLaunchPlan"],
concurrency: Optional[int] = None,
min_successes: Optional[int] = None,
min_success_ratio: float = 1.0,
tolerance: Optional[Union[float, int]] = None,
min_successes: Optional[int] = None, # Deprecated
min_success_ratio: Optional[float] = None, # Deprecated
**kwargs,
):
"""
@@ -385,23 +387,51 @@ def map_task(
size. If the size of the input exceeds the concurrency value, then multiple batches will be run serially until
all inputs are processed. If set to 0, this means unbounded concurrency. If left unspecified, this means the
array node will inherit parallelism from the workflow
:param min_successes: The minimum number of successful executions
:param min_success_ratio: The minimum ratio of successful executions
:param tolerance: Failure tolerance threshold.
If float (0-1): represents minimum success ratio
If int (>1): represents minimum number of successes
:param min_successes: The minimum number of successful executions [Deprecated] Use tolerance instead
:param min_success_ratio: The minimum ratio of successful executions [Deprecated] Use tolerance instead
"""
from flytekit.remote import FlyteLaunchPlan

if min_successes is not None or min_success_ratio is not None:
warnings.warn(
"min_successes and min_success_ratio are deprecated. Please use the 'tolerance' parameter instead",
DeprecationWarning,
stacklevel=2,
)

computed_min_ratio = 1.0
computed_min_success = None

if tolerance is not None:
if isinstance(tolerance, float):
if not 0 <= tolerance <= 1:
raise ValueError("tolerance must be between 0 and 1")
computed_min_ratio = tolerance
elif isinstance(tolerance, int):
if tolerance < 1:
raise ValueError("tolerance must be greater than 0")
computed_min_success = tolerance
else:
raise TypeError("tolerance must be float or int")

final_min_ratio = computed_min_ratio if min_success_ratio is None else min_success_ratio
final_min_successes = computed_min_success if min_successes is None else min_successes

if isinstance(target, (LaunchPlan, FlyteLaunchPlan, ReferenceTask)):
return array_node(
target=target,
concurrency=concurrency,
min_successes=min_successes,
min_success_ratio=min_success_ratio,
min_successes=final_min_successes,
min_success_ratio=final_min_ratio,
)
return array_node_map_task(
task_function=target,
concurrency=concurrency,
min_successes=min_successes,
min_success_ratio=min_success_ratio,
min_successes=final_min_successes,
min_success_ratio=final_min_ratio,
**kwargs,
)

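The branching in the hunk above maps the single tolerance argument onto the two legacy knobs. Condensed into a standalone helper (an illustrative sketch mirroring the diff's logic — the name resolve_tolerance is not part of flytekit, and the explicit bool rejection is an addition the diff does not make), the semantics are: a float in [0, 1] becomes min_success_ratio, an int >= 1 becomes min_successes:

```python
from typing import Optional, Tuple, Union


def resolve_tolerance(
    tolerance: Optional[Union[float, int]],
) -> Tuple[float, Optional[int]]:
    """Return (min_success_ratio, min_successes) for a given tolerance."""
    min_ratio: float = 1.0
    min_successes: Optional[int] = None
    if tolerance is None:
        return min_ratio, min_successes
    if isinstance(tolerance, bool):
        # bool is an int subclass, so True would otherwise hit the int branch.
        raise TypeError("tolerance must be float or int")
    if isinstance(tolerance, float):
        if not 0 <= tolerance <= 1:
            raise ValueError("tolerance must be between 0 and 1")
        min_ratio = tolerance
    elif isinstance(tolerance, int):
        if tolerance < 1:
            raise ValueError("tolerance must be greater than 0")
        min_successes = tolerance
    else:
        raise TypeError("tolerance must be float or int")
    return min_ratio, min_successes
```

So map(task, tolerance=0.9) tolerates up to 10% failed items, while map(task, tolerance=5) requires at least five successes regardless of input size.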
55 changes: 43 additions & 12 deletions flytekit/core/launch_plan.py
@@ -1,6 +1,7 @@
from __future__ import annotations

import typing
import warnings
from typing import Any, Callable, Dict, List, Optional, Type

from flytekit.core import workflow as _annotated_workflow
@@ -129,7 +130,8 @@ def create(
labels: Optional[_common_models.Labels] = None,
annotations: Optional[_common_models.Annotations] = None,
raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None,
max_parallelism: Optional[int] = None,
max_parallelism: Optional[int] = None, # Deprecated: Use concurrency instead
concurrency: Optional[int] = None,
security_context: Optional[security.SecurityContext] = None,
auth_role: Optional[_common_models.AuthRole] = None,
trigger: Optional[LaunchPlanTriggerBase] = None,
@@ -183,7 +185,8 @@ def create(
labels=labels,
annotations=annotations,
raw_output_data_config=raw_output_data_config,
max_parallelism=max_parallelism,
concurrency=concurrency, # Pass new parameter
max_parallelism=max_parallelism, # Pass deprecated parameter
security_context=security_context,
trigger=trigger,
overwrite_cache=overwrite_cache,
@@ -213,7 +216,8 @@ def get_or_create(
labels: Optional[_common_models.Labels] = None,
annotations: Optional[_common_models.Annotations] = None,
raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None,
max_parallelism: Optional[int] = None,
concurrency: Optional[int] = None,
max_parallelism: Optional[int] = None, # Deprecated
security_context: Optional[security.SecurityContext] = None,
auth_role: Optional[_common_models.AuthRole] = None,
trigger: Optional[LaunchPlanTriggerBase] = None,
@@ -241,9 +245,10 @@ def get_or_create(
:param annotations: Optional annotations to attach to executions created by this launch plan.
:param raw_output_data_config: Optional location of offloaded data for things like S3, etc.
:param auth_role: Add an auth role if necessary.
:param max_parallelism: Controls the maximum number of tasknodes that can be run in parallel for the entire
workflow. This is useful to achieve fairness. Note: MapTasks are regarded as one unit, and
parallelism/concurrency of MapTasks is independent from this.
:param concurrency: Controls the maximum number of tasknodes that can be run in parallel for the entire
workflow. This is useful to achieve fairness. Note: MapTasks are regarded as one unit, and
parallelism/concurrency of MapTasks is independent from this.
:param max_parallelism: [Deprecated] Use concurrency instead.
:param trigger: [alpha] This is a new syntax for specifying schedules.
:param overwrite_cache: If set to True, the execution will always overwrite cache
:param auto_activate: If set to True, the launch plan will be activated automatically on registration.
@@ -258,6 +263,7 @@
or annotations is not None
or raw_output_data_config is not None
or auth_role is not None
or concurrency is not None
or max_parallelism is not None
or security_context is not None
or trigger is not None
@@ -296,7 +302,11 @@ def get_or_create(
("labels", labels, cached_outputs["_labels"]),
("annotations", annotations, cached_outputs["_annotations"]),
("raw_output_data_config", raw_output_data_config, cached_outputs["_raw_output_data_config"]),
("max_parallelism", max_parallelism, cached_outputs["_max_parallelism"]),
(
"concurrency",
concurrency if concurrency is not None else max_parallelism,
cached_outputs.get("_concurrency", cached_outputs.get("")),
Reviewer comment on lines +307 to +308 (Contributor): Possible incorrect dictionary key lookup

The cached output lookup for concurrency has an incomplete key in the dictionary get operation: the fallback get() call is passed an empty string instead of a key, which could lead to unexpected behavior. Consider fixing the nested get calls.

Suggested change (check the AI-generated fix before applying):
-                    cached_outputs.get("_concurrency", cached_outputs.get(""))
+                    cached_outputs.get("_concurrency", cached_outputs.get("_max_parallelism"))

Code Review Run #351655 — resolution: incorrectly flagged.

),
("security_context", security_context, cached_outputs["_security_context"]),
("overwrite_cache", overwrite_cache, cached_outputs["_overwrite_cache"]),
("auto_activate", auto_activate, cached_outputs["_auto_activate"]),
@@ -326,7 +336,8 @@ def get_or_create(
labels,
annotations,
raw_output_data_config,
max_parallelism,
concurrency=concurrency,
Reviewer comment (Contributor): Refactor __init__ method signature

The __init__ method has too many parameters (14 > 5) and is missing a docstring and return type annotation.

Suggested change (check the AI-generated fix before applying):
-    def __init__(
-        self,
-        name: str,
-        workflow: _annotated_workflow.WorkflowBase,
-        parameters: _interface_models.ParameterMap,
-        fixed_inputs: _literal_models.LiteralMap,
-        schedule: Optional[_schedule_model.Schedule] = None,
-        notifications: Optional[List[_common_models.Notification]] = None,
-        labels: Optional[_common_models.Labels] = None,
-        annotations: Optional[_common_models.Annotations] = None,
-        raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None,
-        max_parallelism: Optional[int] = None,
-        security_context: Optional[security.SecurityContext] = None,
-        trigger: Optional[LaunchPlanTriggerBase] = None,
-        overwrite_cache: Optional[bool] = None,
-        auto_activate: bool = False,
-    ):
+    @dataclass
+    class Config:
+        """Configuration for LaunchPlan initialization."""
+        name: str
+        workflow: _annotated_workflow.WorkflowBase
+        parameters: _interface_models.ParameterMap
+        fixed_inputs: _literal_models.LiteralMap
+        schedule: _schedule_model.Schedule | None = None
+        notifications: list[_common_models.Notification] | None = None
+        labels: _common_models.Labels | None = None
+        annotations: _common_models.Annotations | None = None
+        raw_output_data_config: _common_models.RawOutputDataConfig | None = None
+        max_parallelism: int | None = None
+        security_context: security.SecurityContext | None = None
+        trigger: LaunchPlanTriggerBase | None = None
+        overwrite_cache: bool | None = None
+        auto_activate: bool = False
+
+    def __init__(self, config: Config) -> None:

Code Review Run #99b31d — resolution: incorrectly flagged.

max_parallelism=max_parallelism,
auth_role=auth_role,
security_context=security_context,
trigger=trigger,
@@ -347,7 +358,8 @@ def __init__(
labels: Optional[_common_models.Labels] = None,
annotations: Optional[_common_models.Annotations] = None,
raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None,
max_parallelism: Optional[int] = None,
concurrency: Optional[int] = None,
max_parallelism: Optional[int] = None, # Deprecated
security_context: Optional[security.SecurityContext] = None,
trigger: Optional[LaunchPlanTriggerBase] = None,
overwrite_cache: Optional[bool] = None,
@@ -367,7 +379,14 @@ def __init__(
self._labels = labels
self._annotations = annotations
self._raw_output_data_config = raw_output_data_config
self._max_parallelism = max_parallelism
self._concurrency = concurrency
self._max_parallelism = concurrency if concurrency is not None else max_parallelism
if max_parallelism is not None:
warnings.warn(
"max_parallelism is deprecated and will be removed in a future version. Use concurrency instead.",
DeprecationWarning,
stacklevel=2,
)
self._security_context = security_context
self._trigger = trigger
self._overwrite_cache = overwrite_cache
@@ -385,7 +404,8 @@ def clone_with(
labels: Optional[_common_models.Labels] = None,
annotations: Optional[_common_models.Annotations] = None,
raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None,
max_parallelism: Optional[int] = None,
concurrency: Optional[int] = None,
max_parallelism: Optional[int] = None, # Deprecated
security_context: Optional[security.SecurityContext] = None,
trigger: Optional[LaunchPlanTriggerBase] = None,
overwrite_cache: Optional[bool] = None,
Expand All @@ -401,6 +421,7 @@ def clone_with(
labels=labels or self.labels,
annotations=annotations or self.annotations,
raw_output_data_config=raw_output_data_config or self.raw_output_data_config,
concurrency=concurrency or self.concurrency,
max_parallelism=max_parallelism or self.max_parallelism,
security_context=security_context or self.security_context,
trigger=trigger,
@@ -466,7 +487,17 @@ def raw_output_data_config(self) -> Optional[_common_models.RawOutputDataConfig]

@property
def max_parallelism(self) -> Optional[int]:
return self._max_parallelism
"""[Deprecated] Use concurrency instead. This property is maintained for backward compatibility"""
warnings.warn(
"max_parallelism is deprecated and will be removed in a future version. Use concurrency instead.",
DeprecationWarning,
stacklevel=2,
)
return self._concurrency

@property
def concurrency(self) -> Optional[int]:
return self._concurrency

@property
def security_context(self) -> Optional[security.SecurityContext]:
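One subtlety in the clone_with hunk above: it forwards concurrency=concurrency or self.concurrency, and since 0 means unbounded concurrency, an explicit 0 is falsy and silently falls back to the old value. A None-aware check preserves it — an illustrative comparison of the two fallback styles, not flytekit code:

```python
from typing import Optional


def fallback_or(new: Optional[int], old: Optional[int]) -> Optional[int]:
    # `or` treats 0 ("unbounded") the same as None ("not provided").
    return new or old


def fallback_none(new: Optional[int], old: Optional[int]) -> Optional[int]:
    # Only a true "not provided" (None) falls back to the previous value.
    return new if new is not None else old


assert fallback_or(0, 5) == 5    # explicit "unbounded" is lost
assert fallback_none(0, 5) == 0  # explicit "unbounded" is kept
```

The same `new if new is not None else old` shape is what __init__ in this diff uses to pick between concurrency and max_parallelism.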
55 changes: 42 additions & 13 deletions flytekit/core/options.py
@@ -1,4 +1,5 @@
import typing
import warnings
from dataclasses import dataclass

from flytekit.models import common as common_models
@@ -8,33 +9,61 @@
@dataclass
class Options(object):
"""
These are options that can be configured for a launchplan during registration or overridden during an execution.
For instance two people may want to run the same workflow but have the offloaded data stored in two different
These are options that can be configured for a launch plan during registration or overridden during an execution.
For instance, two people may want to run the same workflow but have the offloaded data stored in two different
buckets. Or you may want labels or annotations to be different. This object is used when launching an execution
in a Flyte backend, and also when registering launch plans.

Args:
labels: Custom labels to be applied to the execution resource
annotations: Custom annotations to be applied to the execution resource
security_context: Indicates security context for permissions triggered with this launch plan
raw_output_data_config: Optional location of offloaded data for things like S3, etc.
remote prefix for storage location of the form ``s3://<bucket>/key...`` or
``gcs://...`` or ``file://...``. If not specified will use the platform configured default. This is where
Attributes:
labels (typing.Optional[common_models.Labels]): Custom labels to be applied to the execution resource.
annotations (typing.Optional[common_models.Annotations]): Custom annotations to be applied to the execution resource.
raw_output_data_config (typing.Optional[common_models.RawOutputDataConfig]): Optional location of offloaded data
for things like S3, etc. Remote prefix for storage location of the form ``s3://<bucket>/key...`` or
``gcs://...`` or ``file://...``. If not specified, will use the platform-configured default. This is where
the data for offloaded types is stored.
max_parallelism: Controls the maximum number of tasknodes that can be run in parallel for the entire workflow.
notifications: List of notifications for this execution.
disable_notifications: This should be set to true if all notifications are intended to be disabled for this execution.
security_context (typing.Optional[security.SecurityContext]): Indicates security context for permissions triggered
with this launch plan.
concurrency (typing.Optional[int]): Controls the maximum number of task nodes that can be run in parallel for the
entire workflow.
notifications (typing.Optional[typing.List[common_models.Notification]]): List of notifications for this execution.
disable_notifications (typing.Optional[bool]): Set to True if all notifications are intended to be disabled
for this execution.
overwrite_cache (typing.Optional[bool]): When set to True, forces the execution to overwrite any existing cached values.
"""

labels: typing.Optional[common_models.Labels] = None
annotations: typing.Optional[common_models.Annotations] = None
raw_output_data_config: typing.Optional[common_models.RawOutputDataConfig] = None
security_context: typing.Optional[security.SecurityContext] = None
max_parallelism: typing.Optional[int] = None
concurrency: typing.Optional[int] = None
Reviewer comment (Contributor): Consider adding constructor deprecation warning

The parameter max_parallelism has been renamed to concurrency. While backward compatibility is maintained through property and setter methods, consider adding a deprecation warning in the constructor when max_parallelism is used.

Suggested change (check the AI-generated fix before applying):
-    def __init__(self, **kwargs):
+    def __init__(self, max_parallelism=None, **kwargs):
+        if max_parallelism is not None:
+            warnings.warn(
+                "max_parallelism is deprecated and will be removed in a future version. Use concurrency instead.",
+                DeprecationWarning,
+                stacklevel=2)
+        super().__init__(**kwargs)

Code Review Run #cbd7b1 — resolution: incorrectly flagged.

Reviewer comment (Contributor): Consider adding concurrency parameter validation

Consider adding validation for the concurrency parameter to ensure it is a positive integer when set. A negative or zero value for concurrency could cause unexpected behavior.

Suggested change (check the AI-generated fix before applying):
-concurrency: typing.Optional[int] = None
+_concurrency: typing.Optional[int] = None
+
+@property
+def concurrency(self) -> typing.Optional[int]:
+    return self._concurrency
+
+@concurrency.setter
+def concurrency(self, value: typing.Optional[int]):
+    if value is not None and value <= 0:
+        raise ValueError('concurrency must be a positive integer')
+    self._concurrency = value

Code Review Run #351655 — resolution: incorrectly flagged.

notifications: typing.Optional[typing.List[common_models.Notification]] = None
disable_notifications: typing.Optional[bool] = None
overwrite_cache: typing.Optional[bool] = None

@property
def max_parallelism(self) -> typing.Optional[int]:
"""
[Deprecated] Use concurrency instead. This property is maintained for backward compatibility
"""
warnings.warn(
"max_parallelism is deprecated and will be removed in a future version. Use concurrency instead.",
DeprecationWarning,
stacklevel=2,
)
return self.concurrency

@max_parallelism.setter
def max_parallelism(self, value: typing.Optional[int]):
"""
Setter for max_parallelism (deprecated in favor of concurrency)
"""
warnings.warn(
"max_parallelism is deprecated and will be removed in a future version. Use concurrency instead.",
DeprecationWarning,
stacklevel=2,
)
self.concurrency = value
Reviewer comment on lines +43 to +65 (Contributor): Consider using standard deprecation decorator pattern

Consider using a decorator such as deprecated (e.g. warnings.deprecated in Python 3.13+, or the third-party deprecated package) instead of manually implementing deprecation warnings. This would make the code more maintainable and consistent with common Python deprecation patterns.

Suggested change (check the AI-generated fix before applying):
+@property
+@deprecated("Use concurrency instead", DeprecationWarning)
+def max_parallelism(self) -> typing.Optional[int]:
+    return self.concurrency
+
+@max_parallelism.setter
+@deprecated("Use concurrency instead", DeprecationWarning)
+def max_parallelism(self, value: typing.Optional[int]):
+    self.concurrency = value

Code Review Run #cbd7b1 — resolution: incorrectly flagged.


@classmethod
def default_from(
cls,
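The property/setter pair in the Options diff can be exercised in isolation. Here is a minimal self-contained version of the rename pattern — the class name Opts is illustrative, not the actual flytekit Options dataclass:

```python
import typing
import warnings
from dataclasses import dataclass


@dataclass
class Opts:
    concurrency: typing.Optional[int] = None

    # The property is a class attribute, not a dataclass field, so it does
    # not appear in the generated __init__; only `concurrency` does.
    @property
    def max_parallelism(self) -> typing.Optional[int]:
        warnings.warn(
            "max_parallelism is deprecated; use concurrency instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.concurrency

    @max_parallelism.setter
    def max_parallelism(self, value: typing.Optional[int]) -> None:
        warnings.warn(
            "max_parallelism is deprecated; use concurrency instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        self.concurrency = value
```

Reads and writes through the old name transparently hit the new field, so code such as `opts.max_parallelism = 3` keeps working while emitting a DeprecationWarning.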