diff --git a/flytekit/__init__.py b/flytekit/__init__.py
index 74a50aff7f..aaf83f8d8e 100644
--- a/flytekit/__init__.py
+++ b/flytekit/__init__.py
@@ -208,6 +208,7 @@
 import os
 import sys
+import warnings
 from typing import Generator
 
 from rich import traceback
@@ -219,10 +220,9 @@
 else:
     from importlib.metadata import entry_points
 
-
 from flytekit._version import __version__
 from flytekit.configuration import Config
-from flytekit.core.array_node_map_task import map_task
+from flytekit.core.array_node_map_task import map
 from flytekit.core.artifact import Artifact
 from flytekit.core.base_sql_task import SQLTask
 from flytekit.core.base_task import SecurityContext, TaskMetadata, kwtypes
@@ -269,6 +269,19 @@
     StructuredDatasetType,
 )
 
+
+def __getattr__(name):
+    # PEP 562: warn only when the deprecated 'map_task' alias is actually accessed,
+    # rather than on every 'import flytekit'.
+    if name == "map_task":
+        warnings.warn(
+            "'map_task' is deprecated and will be removed in a future version. Use 'map' instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return map
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
+
 
 def current_context() -> ExecutionParameters:
     """
diff --git a/flytekit/clis/sdk_in_container/run.py b/flytekit/clis/sdk_in_container/run.py
index 7a08ef31af..2f9911cf2d 100644
--- a/flytekit/clis/sdk_in_container/run.py
+++ b/flytekit/clis/sdk_in_container/run.py
@@ -255,9 +255,10 @@ class RunLevelParams(PyFlyteParams):
             ),
         )
     )
-    max_parallelism: int = make_click_option_field(
+
+    concurrency: int = make_click_option_field(
         click.Option(
-            param_decls=["--max-parallelism"],
+            param_decls=["--concurrency"],
             required=False,
             type=int,
             show_default=True,
@@ -265,6 +266,15 @@
             " project/domain defaults are used. If 0 then it is unlimited.",
         )
     )
+    max_parallelism: int = make_click_option_field(
+        click.Option(
+            param_decls=["--max-parallelism"],
+            required=False,
+            type=int,
+            show_default=True,
+            help="[Deprecated] Use --concurrency instead",
+        )
+    )
     disable_notifications: bool = make_click_option_field(
         click.Option(
             param_decls=["--disable-notifications"],
@@ -516,7 +526,7 @@ def options_from_run_params(run_level_params: RunLevelParams) -> Options:
         raw_output_data_config=RawOutputDataConfig(output_location_prefix=run_level_params.raw_output_data_prefix)
         if run_level_params.raw_output_data_prefix
         else None,
-        max_parallelism=run_level_params.max_parallelism,
+        concurrency=run_level_params.concurrency if run_level_params.concurrency is not None else run_level_params.max_parallelism,
         disable_notifications=run_level_params.disable_notifications,
         security_context=security.SecurityContext(
             run_as=security.Identity(k8s_service_account=run_level_params.service_account)
diff --git a/flytekit/core/array_node_map_task.py b/flytekit/core/array_node_map_task.py
index 05690e175b..b30039cb1c 100644
--- a/flytekit/core/array_node_map_task.py
+++ b/flytekit/core/array_node_map_task.py
@@ -3,6 +3,7 @@
 import hashlib
 import math
 import os  # TODO: use flytekit logger
+import warnings
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union, cast
@@ -369,11 +370,12 @@
         return outputs
 
 
-def map_task(
+def map(
     target: Union[LaunchPlan, PythonFunctionTask, "FlyteLaunchPlan"],
     concurrency: Optional[int] = None,
-    min_successes: Optional[int] = None,
-    min_success_ratio: float = 1.0,
+    tolerance: Optional[Union[float, int]] = None,
+    min_successes: Optional[int] = None,  # Deprecated
+    min_success_ratio: Optional[float] = None,  # Deprecated
     **kwargs,
 ):
     """
@@ -385,23 +387,51 @@
         size. If the size of the input exceeds the concurrency value, then multiple batches will be run serially until
         all inputs are processed. If set to 0, this means unbounded concurrency.
         If left unspecified, this means the array node will inherit parallelism from the workflow
-    :param min_successes: The minimum number of successful executions
-    :param min_success_ratio: The minimum ratio of successful executions
+    :param tolerance: Failure tolerance threshold.
+        If float (0 to 1): represents the minimum success ratio
+        If int (>= 1): represents the minimum number of successes
+    :param min_successes: The minimum number of successful executions. [Deprecated] Use tolerance instead
+    :param min_success_ratio: The minimum ratio of successful executions. [Deprecated] Use tolerance instead
     """
     from flytekit.remote import FlyteLaunchPlan
 
+    if min_successes is not None or min_success_ratio is not None:
+        warnings.warn(
+            "'min_successes' and 'min_success_ratio' are deprecated. Please use the 'tolerance' parameter instead",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+
+    computed_min_ratio = 1.0
+    computed_min_success = None
+
+    if tolerance is not None:
+        if isinstance(tolerance, float):
+            if not 0 <= tolerance <= 1:
+                raise ValueError("tolerance must be between 0 and 1")
+            computed_min_ratio = tolerance
+        elif isinstance(tolerance, int):
+            if tolerance < 1:
+                raise ValueError("tolerance must be at least 1")
+            computed_min_success = tolerance
+        else:
+            raise TypeError("tolerance must be float or int")
+
+    final_min_ratio = computed_min_ratio if min_success_ratio is None else min_success_ratio
+    final_min_successes = computed_min_success if min_successes is None else min_successes
+
     if isinstance(target, (LaunchPlan, FlyteLaunchPlan, ReferenceTask)):
         return array_node(
             target=target,
             concurrency=concurrency,
-            min_successes=min_successes,
-            min_success_ratio=min_success_ratio,
+            min_successes=final_min_successes,
+            min_success_ratio=final_min_ratio,
         )
     return array_node_map_task(
         task_function=target,
         concurrency=concurrency,
-        min_successes=min_successes,
-        min_success_ratio=min_success_ratio,
+        min_successes=final_min_successes,
+        min_success_ratio=final_min_ratio,
         **kwargs,
     )
diff --git a/flytekit/core/launch_plan.py b/flytekit/core/launch_plan.py
index 05ba393dd4..723a7e3eae 100644
--- a/flytekit/core/launch_plan.py
+++ b/flytekit/core/launch_plan.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import typing
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Type
 
 from flytekit.core import workflow as _annotated_workflow
@@ -129,7 +130,8 @@ def create(
         labels: Optional[_common_models.Labels] = None,
         annotations: Optional[_common_models.Annotations] = None,
         raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None,
-        max_parallelism: Optional[int] = None,
+        max_parallelism: Optional[int] = None,  # Deprecated: Use concurrency instead
+        concurrency: Optional[int] = None,
         security_context: Optional[security.SecurityContext] = None,
         auth_role: Optional[_common_models.AuthRole] = None,
         trigger: Optional[LaunchPlanTriggerBase] = None,
@@ -183,7 +185,8 @@
             labels=labels,
             annotations=annotations,
             raw_output_data_config=raw_output_data_config,
-            max_parallelism=max_parallelism,
+            concurrency=concurrency,  # Pass new parameter
+            max_parallelism=max_parallelism,  # Pass deprecated parameter
             security_context=security_context,
             trigger=trigger,
             overwrite_cache=overwrite_cache,
@@ -213,7 +216,8 @@ def get_or_create(
         labels: Optional[_common_models.Labels] = None,
         annotations: Optional[_common_models.Annotations] = None,
         raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None,
-        max_parallelism: Optional[int] = None,
+        concurrency: Optional[int] = None,
+        max_parallelism: Optional[int] = None,  # Deprecated
         security_context: Optional[security.SecurityContext] = None,
         auth_role: Optional[_common_models.AuthRole] = None,
         trigger: Optional[LaunchPlanTriggerBase] = None,
@@ -241,9 +245,10 @@
         :param annotations: Optional annotations to attach to executions created by this launch plan.
         :param raw_output_data_config: Optional location of offloaded data for things like S3, etc.
         :param auth_role: Add an auth role if necessary.
-        :param max_parallelism: Controls the maximum number of tasknodes that can be run in parallel for the entire
-            workflow. This is useful to achieve fairness. Note: MapTasks are regarded as one unit, and
-            parallelism/concurrency of MapTasks is independent from this.
+        :param concurrency: Controls the maximum number of tasknodes that can be run in parallel for the entire
+            workflow. This is useful to achieve fairness. Note: MapTasks are regarded as one unit, and
+            parallelism/concurrency of MapTasks is independent from this.
+        :param max_parallelism: [Deprecated] Use concurrency instead.
         :param trigger: [alpha] This is a new syntax for specifying schedules.
         :param overwrite_cache: If set to True, the execution will always overwrite cache
         :param auto_activate: If set to True, the launch plan will be activated automatically on registration.
@@ -258,6 +263,7 @@
             or annotations is not None
             or raw_output_data_config is not None
             or auth_role is not None
+            or concurrency is not None
             or max_parallelism is not None
             or security_context is not None
             or trigger is not None
@@ -296,7 +302,11 @@
                 ("labels", labels, cached_outputs["_labels"]),
                 ("annotations", annotations, cached_outputs["_annotations"]),
                 ("raw_output_data_config", raw_output_data_config, cached_outputs["_raw_output_data_config"]),
-                ("max_parallelism", max_parallelism, cached_outputs["_max_parallelism"]),
+                (
+                    "concurrency",
+                    concurrency if concurrency is not None else max_parallelism,
+                    cached_outputs.get("_concurrency", cached_outputs.get("_max_parallelism")),
+                ),
                 ("security_context", security_context, cached_outputs["_security_context"]),
                 ("overwrite_cache", overwrite_cache, cached_outputs["_overwrite_cache"]),
                 ("auto_activate", auto_activate, cached_outputs["_auto_activate"]),
@@ -326,7 +336,8 @@
                 labels,
                 annotations,
                 raw_output_data_config,
-                max_parallelism,
+                concurrency=concurrency,
+                max_parallelism=max_parallelism,
                 auth_role=auth_role,
                 security_context=security_context,
                 trigger=trigger,
@@ -347,7 +358,8 @@
         labels: Optional[_common_models.Labels] = None,
         annotations: Optional[_common_models.Annotations] = None,
         raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None,
-        max_parallelism: Optional[int] = None,
+        concurrency: Optional[int] = None,
+        max_parallelism: Optional[int] = None,  # Deprecated
         security_context: Optional[security.SecurityContext] = None,
         trigger: Optional[LaunchPlanTriggerBase] = None,
         overwrite_cache: Optional[bool] = None,
@@ -367,7 +379,15 @@
         self._labels = labels
         self._annotations = annotations
         self._raw_output_data_config = raw_output_data_config
-        self._max_parallelism = max_parallelism
+        # Coalesce so both the new and the deprecated accessors see the value.
+        self._concurrency = concurrency if concurrency is not None else max_parallelism
+        self._max_parallelism = self._concurrency
+        if max_parallelism is not None:
+            warnings.warn(
+                "max_parallelism is deprecated and will be removed in a future version. Use concurrency instead.",
Use concurrency instead.", + DeprecationWarning, + stacklevel=2, + ) self._security_context = security_context self._trigger = trigger self._overwrite_cache = overwrite_cache @@ -385,7 +404,8 @@ def clone_with( labels: Optional[_common_models.Labels] = None, annotations: Optional[_common_models.Annotations] = None, raw_output_data_config: Optional[_common_models.RawOutputDataConfig] = None, - max_parallelism: Optional[int] = None, + concurrency: Optional[int] = None, + max_parallelism: Optional[int] = None, # Dreprecated security_context: Optional[security.SecurityContext] = None, trigger: Optional[LaunchPlanTriggerBase] = None, overwrite_cache: Optional[bool] = None, @@ -401,6 +421,7 @@ def clone_with( labels=labels or self.labels, annotations=annotations or self.annotations, raw_output_data_config=raw_output_data_config or self.raw_output_data_config, + concurrency=concurrency or self.concurrency, max_parallelism=max_parallelism or self.max_parallelism, security_context=security_context or self.security_context, trigger=trigger, @@ -466,7 +487,17 @@ def raw_output_data_config(self) -> Optional[_common_models.RawOutputDataConfig] @property def max_parallelism(self) -> Optional[int]: - return self._max_parallelism + """[Deprecated] Use concurrency instead. This property is maintained for backward compatibility""" + warnings.warn( + "max_parallelism is deprecated and will be removed in a future version. Use concurrency instead.", + DeprecationWarning, + stacklevel=2, + ) + return self._concurrency + + @property + def concurrency(self) -> Optional[int]: + return self._concurrency @property def security_context(self) -> Optional[security.SecurityContext]: diff --git a/flytekit/core/options.py b/flytekit/core/options.py index 79d46c2039..0592d1f839 100644 --- a/flytekit/core/options.py +++ b/flytekit/core/options.py @@ -1,4 +1,5 @@ import typing +import warnings from dataclasses import dataclass from flytekit.models import common as common_models @@ -8,33 +9,61 @@ @dataclass class Options(object): """ - These are options that can be configured for a launchplan during registration or overridden during an execution. - For instance two people may want to run the same workflow but have the offloaded data stored in two different + These are options that can be configured for a launch plan during registration or overridden during an execution. + For instance, two people may want to run the same workflow but have the offloaded data stored in two different buckets. Or you may want labels or annotations to be different. This object is used when launching an execution in a Flyte backend, and also when registering launch plans. - Args: - labels: Custom labels to be applied to the execution resource - annotations: Custom annotations to be applied to the execution resource - security_context: Indicates security context for permissions triggered with this launch plan - raw_output_data_config: Optional location of offloaded data for things like S3, etc. - remote prefix for storage location of the form ``s3:///key...`` or - ``gcs://...`` or ``file://...``. If not specified will use the platform configured default. This is where + Attributes: + labels (typing.Optional[common_models.Labels]): Custom labels to be applied to the execution resource. + annotations (typing.Optional[common_models.Annotations]): Custom annotations to be applied to the execution resource. + raw_output_data_config (typing.Optional[common_models.RawOutputDataConfig]): Optional location of offloaded data + for things like S3, etc. 
Remote prefix for storage location of the form ``s3:///key...`` or + ``gcs://...`` or ``file://...``. If not specified, will use the platform-configured default. This is where the data for offloaded types is stored. - max_parallelism: Controls the maximum number of tasknodes that can be run in parallel for the entire workflow. - notifications: List of notifications for this execution. - disable_notifications: This should be set to true if all notifications are intended to be disabled for this execution. + security_context (typing.Optional[security.SecurityContext]): Indicates security context for permissions triggered + with this launch plan. + concurrency (typing.Optional[int]): Controls the maximum number of task nodes that can be run in parallel for the + entire workflow. + notifications (typing.Optional[typing.List[common_models.Notification]]): List of notifications for this execution. + disable_notifications (typing.Optional[bool]): Set to True if all notifications are intended to be disabled + for this execution. + overwrite_cache (typing.Optional[bool]): When set to True, forces the execution to overwrite any existing cached values. """ labels: typing.Optional[common_models.Labels] = None annotations: typing.Optional[common_models.Annotations] = None raw_output_data_config: typing.Optional[common_models.RawOutputDataConfig] = None security_context: typing.Optional[security.SecurityContext] = None - max_parallelism: typing.Optional[int] = None + concurrency: typing.Optional[int] = None notifications: typing.Optional[typing.List[common_models.Notification]] = None disable_notifications: typing.Optional[bool] = None overwrite_cache: typing.Optional[bool] = None + @property + def max_parallelism(self) -> typing.Optional[int]: + """ + [Deprecated] Use concurrency instead. This property is maintained for backward compatibility + """ + warnings.warn( + "max_parallelism is deprecated and will be removed in a future version. Use concurrency instead.", + DeprecationWarning, + stacklevel=2, + ) + return self.concurrency + + @max_parallelism.setter + def max_parallelism(self, value: typing.Optional[int]): + """ + Setter for max_parallelism (deprecated in favor of concurrency) + """ + warnings.warn( + "max_parallelism is deprecated and will be removed in a future version. Use concurrency instead.", + DeprecationWarning, + stacklevel=2, + ) + self.concurrency = value + @classmethod def default_from( cls, diff --git a/flytekit/models/execution.py b/flytekit/models/execution.py index 0019e4d79b..6e193af6f4 100644 --- a/flytekit/models/execution.py +++ b/flytekit/models/execution.py @@ -2,6 +2,7 @@ import datetime import typing +import warnings from datetime import timezone as _timezone from typing import Optional @@ -177,6 +178,7 @@ def __init__( annotations=None, auth_role=None, raw_output_data_config=None, + concurrency: Optional[int] = None, max_parallelism: Optional[int] = None, security_context: Optional[security.SecurityContext] = None, overwrite_cache: Optional[bool] = None, @@ -195,9 +197,10 @@ def __init__( :param flytekit.models.common.Annotations annotations: Annotations to apply to the execution :param flytekit.models.common.AuthRole auth_role: The authorization method with which to execute the workflow. :param raw_output_data_config: Optional location of offloaded data for things like S3, etc. 
-        :param max_parallelism int: Controls the maximum number of tasknodes that can be run in parallel for the entire
+        :param concurrency int: Controls the maximum number of tasknodes that can be run in parallel for the entire
             workflow. This is useful to achieve fairness. Note: MapTasks are regarded as one unit, and
             parallelism/concurrency of MapTasks is independent from this.
+        :param max_parallelism: [Deprecated] Use concurrency instead.
         :param security_context: Optional security context to use for this execution.
         :param overwrite_cache: Optional flag to overwrite the cache for this execution.
         :param interruptible: Optional flag to override the default interruptible flag of the executed entity.
@@ -205,6 +208,7 @@
         :param tags: Optional list of tags to apply to the execution.
         :param execution_cluster_label: Optional execution cluster label to use for this execution.
         """
+
         self._launch_plan = launch_plan
         self._metadata = metadata
         self._notifications = notifications
@@ -213,7 +217,13 @@
         self._annotations = annotations or _common_models.Annotations({})
         self._auth_role = auth_role or _common_models.AuthRole()
         self._raw_output_data_config = raw_output_data_config
-        self._max_parallelism = max_parallelism
+
+        if max_parallelism is not None:
+            warnings.warn("max_parallelism is deprecated. Use concurrency instead.", DeprecationWarning, stacklevel=2)
+
+        # Prefer the explicit 'concurrency' value when both are supplied.
+        self._concurrency = concurrency if concurrency is not None else max_parallelism
+
         self._security_context = security_context
         self._overwrite_cache = overwrite_cache
         self._interruptible = interruptible
@@ -281,7 +291,19 @@
 
     @property
     def max_parallelism(self) -> int:
-        return self._max_parallelism
+        """
+        Deprecated. Use concurrency instead.
+        """
+        warnings.warn(
+            "max_parallelism is deprecated and will be removed in a future version. 
Use concurrency instead.", + DeprecationWarning, + stacklevel=2, + ) + return self._concurrency + + @property + def concurrency(self) -> int: + return self._concurrency @property def security_context(self) -> typing.Optional[security.SecurityContext]: @@ -326,7 +348,7 @@ def to_flyte_idl(self): raw_output_data_config=self._raw_output_data_config.to_flyte_idl() if self._raw_output_data_config else None, - max_parallelism=self.max_parallelism, + max_parallelism=self._concurrency, security_context=self.security_context.to_flyte_idl() if self.security_context else None, overwrite_cache=self.overwrite_cache, interruptible=BoolValue(value=self.interruptible) if self.interruptible is not None else None, @@ -355,7 +377,7 @@ def from_flyte_idl(cls, p): raw_output_data_config=_common_models.RawOutputDataConfig.from_flyte_idl(p.raw_output_data_config) if p.HasField("raw_output_data_config") else None, - max_parallelism=p.max_parallelism, + concurrency=p.max_parallelism, security_context=security.SecurityContext.from_flyte_idl(p.security_context) if p.security_context else None, diff --git a/flytekit/models/launch_plan.py b/flytekit/models/launch_plan.py index c828395996..37c23814ee 100644 --- a/flytekit/models/launch_plan.py +++ b/flytekit/models/launch_plan.py @@ -135,6 +135,7 @@ def __init__( annotations: _common.Annotations, auth_role: _common.AuthRole, raw_output_data_config: _common.RawOutputDataConfig, + concurrency: typing.Optional[int] = None, max_parallelism: typing.Optional[int] = None, security_context: typing.Optional[security.SecurityContext] = None, overwrite_cache: typing.Optional[bool] = None, @@ -153,9 +154,10 @@ def __init__( :param flytekit.models.common.AuthRole auth_role: The auth method with which to execute the workflow. :param flytekit.models.common.RawOutputDataConfig raw_output_data_config: Value for where to store offloaded data like Blobs and Schemas. - :param max_parallelism int: Controls the maximum number of tasknodes that can be run in parallel for the entire + :param concurrency int: Controls the maximum number of tasknodes that can be run in parallel for the entire workflow. This is useful to achieve fairness. Note: MapTasks are regarded as one unit, and parallelism/concurrency of MapTasks is independent from this. + :param max_parallelism: Deprecated. Use concurrency instead :param security_context: This can be used to add security information to a LaunchPlan, which will be used by every execution """ @@ -167,7 +169,16 @@ def __init__( self._annotations = annotations self._auth_role = auth_role self._raw_output_data_config = raw_output_data_config - self._max_parallelism = max_parallelism + self._concurrency = concurrency + self._max_parallelism = concurrency if concurrency is not None else max_parallelism + if max_parallelism is not None: + import warnings + + warnings.warn( + "max_parallelism is deprecated and will be removed in a future version. 
Use concurrency instead.", + DeprecationWarning, + stacklevel=2, + ) self._security_context = security_context self._overwrite_cache = overwrite_cache @@ -234,6 +245,10 @@ def raw_output_data_config(self): """ return self._raw_output_data_config + @property + def concurrency(self) -> typing.Optional[int]: + return self._concurrency + @property def max_parallelism(self) -> typing.Optional[int]: return self._max_parallelism @@ -259,7 +274,7 @@ def to_flyte_idl(self): annotations=self.annotations.to_flyte_idl(), auth_role=self.auth_role.to_flyte_idl() if self.auth_role else None, raw_output_data_config=self.raw_output_data_config.to_flyte_idl(), - max_parallelism=self.max_parallelism, + max_parallelism=self.concurrency if self.concurrency is not None else self.max_parallelism, security_context=self.security_context.to_flyte_idl() if self.security_context else None, overwrite_cache=self.overwrite_cache if self.overwrite_cache else None, ) @@ -270,6 +285,7 @@ def from_flyte_idl(cls, pb2): :param flyteidl.admin.launch_plan_pb2.LaunchPlanSpec pb2: :rtype: LaunchPlanSpec """ + auth_role = None # First check the newer field, auth_role. if pb2.auth_role is not None and (pb2.auth_role.assumable_iam_role or pb2.auth_role.kubernetes_service_account): @@ -281,6 +297,24 @@ def from_flyte_idl(cls, pb2): else: auth_role = _common.AuthRole(assumable_iam_role=pb2.auth.kubernetes_service_account) + # Handle concurrency/max_parallelism transition + concurrency = None + max_parallelism = None + + if hasattr(pb2, "concurrency"): + try: + if pb2.HasField("concurrency"): + concurrency = pb2.concurrency + except ValueError: + pass # Field doesn't exist in protobuf yet + + # Fallback to max_parallelism (deprecated field) + if hasattr(pb2, "max_parallelism"): + max_parallelism = pb2.max_parallelism + + # Use concurrency if available, otherwise use max_parallelism + final_concurrency = concurrency if concurrency is not None else max_parallelism + return cls( workflow_id=_identifier.Identifier.from_flyte_idl(pb2.workflow_id), entity_metadata=LaunchPlanMetadata.from_flyte_idl(pb2.entity_metadata), @@ -290,6 +324,7 @@ def from_flyte_idl(cls, pb2): annotations=_common.Annotations.from_flyte_idl(pb2.annotations), auth_role=auth_role, raw_output_data_config=_common.RawOutputDataConfig.from_flyte_idl(pb2.raw_output_data_config), + concurrency=final_concurrency, max_parallelism=pb2.max_parallelism, security_context=security.SecurityContext.from_flyte_idl(pb2.security_context) if pb2.security_context diff --git a/flytekit/remote/entities.py b/flytekit/remote/entities.py index fd78d4c3c4..e6dae15c94 100644 --- a/flytekit/remote/entities.py +++ b/flytekit/remote/entities.py @@ -760,6 +760,18 @@ class FlyteLaunchPlan(hash_mixin.HashOnReferenceMixin, RemoteEntity, _launch_pla """A class encapsulating a remote Flyte launch plan.""" def __init__(self, id, *args, **kwargs): + if "concurrency" in kwargs: + kwargs["max_parallelism"] = kwargs.pop("concurrency") + elif "max_parallelism" in kwargs: + import warnings + + warnings.warn( + "max_parallelism is deprecated and will be removed in a future version. Use concurrency instead.", + DeprecationWarning, + stacklevel=2, + ) + self._max_parallelism = kwargs["max_parallelism"] + super(FlyteLaunchPlan, self).__init__(*args, **kwargs) # Set all the attributes we expect this class to have self._id = id @@ -770,6 +782,21 @@ def __init__(self, id, *args, **kwargs): # If fetched when creating this object, can store it here. 
        self._flyte_workflow = None
 
+    @property
+    def concurrency(self) -> int:
+        return self._max_parallelism
+
+    @property
+    def max_parallelism(self) -> int:
+        import warnings
+
+        warnings.warn(
+            "max_parallelism is deprecated and will be removed in a future version. Use concurrency instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        return self._max_parallelism
+
     @property
     def name(self) -> str:
         return self._name
diff --git a/flytekit/remote/remote.py b/flytekit/remote/remote.py
index ef8b28d866..b3efb535ee 100644
--- a/flytekit/remote/remote.py
+++ b/flytekit/remote/remote.py
@@ -1574,7 +1574,7 @@ def _execute(
             annotations=options.annotations,
             raw_output_data_config=options.raw_output_data_config,
             auth_role=None,
-            max_parallelism=options.max_parallelism,
+            concurrency=options.concurrency,
             security_context=options.security_context,
             envs=common_models.Envs(envs) if envs else None,
             tags=tags,
diff --git a/flytekit/tools/translator.py b/flytekit/tools/translator.py
index a295f75078..dd922cb02c 100644
--- a/flytekit/tools/translator.py
+++ b/flytekit/tools/translator.py
@@ -354,6 +354,11 @@
     else:
         lc = None
 
+    # Prefer the new 'concurrency' value. 'Options.max_parallelism' and
+    # 'LaunchPlan.max_parallelism' are now deprecated aliases that warn and read
+    # through to 'concurrency', so checking 'concurrency' covers both spellings.
+    concurrency = options.concurrency if options.concurrency is not None else entity.concurrency
+
     lps = _launch_plan_models.LaunchPlanSpec(
         workflow_id=wf_id,
         entity_metadata=_launch_plan_models.LaunchPlanMetadata(
@@ -367,7 +372,7 @@
         annotations=options.annotations or entity.annotations or _common_models.Annotations({}),
         auth_role=None,
         raw_output_data_config=raw_prefix_config,
-        max_parallelism=options.max_parallelism or entity.max_parallelism,
+        concurrency=concurrency,
         security_context=options.security_context or entity.security_context,
         overwrite_cache=options.overwrite_cache or entity.overwrite_cache,
     )
diff --git a/plugins/flytekit-k8s-pod/tests/test_pod.py b/plugins/flytekit-k8s-pod/tests/test_pod.py
index 9d3a67e2ed..99f0949fb8 100644
--- a/plugins/flytekit-k8s-pod/tests/test_pod.py
+++ b/plugins/flytekit-k8s-pod/tests/test_pod.py
@@ -8,7 +8,7 @@
 from kubernetes.client import ApiClient
 from kubernetes.client.models import V1Container, V1EnvVar, V1PodSpec, V1ResourceRequirements, V1VolumeMount
 
-from flytekit import Resources, TaskMetadata, dynamic, map_task, task
+from flytekit import Resources, TaskMetadata, dynamic, map, task
 from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings
 from flytekit.core import context_manager
 from flytekit.core.type_engine import TypeEngine
@@ -328,7 +328,7 @@ def test_map_pod_task_serialization():
     def simple_pod_task(i: int):
         pass
 
-    mapped_task = map_task(simple_pod_task, metadata=TaskMetadata(retries=1))
+    mapped_task = map(simple_pod_task, metadata=TaskMetadata(retries=1))
     default_img = Image(name="default", fqn="test", tag="tag")
     serialization_settings = SerializationSettings(
         project="project",
diff --git a/plugins/flytekit-papermill/tests/test_task.py b/plugins/flytekit-papermill/tests/test_task.py
index efca238dbd..9313e3d7b8 100644
--- a/plugins/flytekit-papermill/tests/test_task.py
+++ b/plugins/flytekit-papermill/tests/test_task.py
@@ -15,7 +15,7 @@
 from kubernetes.client import V1Container, V1PodSpec
 
 import flytekit
-from flytekit import StructuredDataset, kwtypes, map_task, task, workflow
+from flytekit import StructuredDataset, kwtypes, map, task, workflow
 from flytekit.clients.friendly import SynchronousFlyteClient
 from flytekit.clis.sdk_in_container import pyflyte
 from flytekit.configuration import Image, ImageConfig
@@ -414,7 +414,7 @@ def create_sd() -> StructuredDataset:
 def test_map_over_notebook_task():
     @workflow
     def wf(a: float) -> typing.List[float]:
-        return map_task(nb_sub_task)(a=[a, a])
+        return map(nb_sub_task)(a=[a, a])
 
     assert wf(a=3.14) == [9.8596, 9.8596]
 
diff --git a/pydoclint-errors-baseline.txt b/pydoclint-errors-baseline.txt
index 043fcce00b..221ec1c5f7 100644
--- a/pydoclint-errors-baseline.txt
+++ b/pydoclint-errors-baseline.txt
@@ -88,10 +88,6 @@
 flytekit/core/notification.py
     DOC301: Class `Email`: __init__() should not have a docstring; please combine it with the docstring of the class
     DOC301: Class `Slack`: __init__() should not have a docstring; please combine it with the docstring of the class
 --------------------
-flytekit/core/options.py
-    DOC601: Class `Options`: Class docstring contains fewer class attributes than actual class attributes. (Please read https://jsh9.github.io/pydoclint/checking_class_attributes.html on how to correctly document class attributes.)
-    DOC603: Class `Options`: Class docstring attributes are different from actual class attributes. 
(Or could be other formatting issues: https://jsh9.github.io/pydoclint/violation_codes.html#notes-on-doc103 ). Attributes in the class definition but not in the docstring: [annotations: typing.Optional[common_models.Annotations], disable_notifications: typing.Optional[bool], labels: typing.Optional[common_models.Labels], max_parallelism: typing.Optional[int], notifications: typing.Optional[typing.List[common_models.Notification]], overwrite_cache: typing.Optional[bool], raw_output_data_config: typing.Optional[common_models.RawOutputDataConfig], security_context: typing.Optional[security.SecurityContext]]. (Please read https://jsh9.github.io/pydoclint/checking_class_attributes.html on how to correctly document class attributes.) --------------------- flytekit/core/promise.py DOC301: Class `NodeOutput`: __init__() should not have a docstring; please combine it with the docstring of the class -------------------- diff --git a/tests/flytekit/integration/remote/workflows/basic/array_map.py b/tests/flytekit/integration/remote/workflows/basic/array_map.py index 8e2311af09..a2f65c7232 100644 --- a/tests/flytekit/integration/remote/workflows/basic/array_map.py +++ b/tests/flytekit/integration/remote/workflows/basic/array_map.py @@ -1,7 +1,7 @@ import typing from functools import partial -from flytekit import map_task, task, workflow +from flytekit import map, task, workflow @task @@ -12,4 +12,4 @@ def fn(x: int, y: int) -> int: @workflow def workflow_with_maptask(data: typing.List[int], y: int) -> typing.List[int]: partial_fn = partial(fn, y=y) - return map_task(partial_fn)(x=data) + return map(partial_fn)(x=data) diff --git a/tests/flytekit/integration/remote/workflows/basic/pydantic_wf.py b/tests/flytekit/integration/remote/workflows/basic/pydantic_wf.py index d5e9c32170..4ed3acec5c 100644 --- a/tests/flytekit/integration/remote/workflows/basic/pydantic_wf.py +++ b/tests/flytekit/integration/remote/workflows/basic/pydantic_wf.py @@ -1,6 +1,6 @@ from pydantic import BaseModel -from flytekit import map_task +from flytekit import map from typing import List from flytekit import task, workflow @@ -14,7 +14,7 @@ def print_float(my_float: float): @workflow def wf(bm: MyBaseModel = MyBaseModel()): - map_task(print_float)(my_float=bm.my_floats) + map(print_float)(my_float=bm.my_floats) if __name__ == "__main__": wf() diff --git a/tests/flytekit/unit/core/test_array_node.py b/tests/flytekit/unit/core/test_array_node.py index 41f5e12bc9..3b1f5efdcc 100644 --- a/tests/flytekit/unit/core/test_array_node.py +++ b/tests/flytekit/unit/core/test_array_node.py @@ -8,7 +8,7 @@ from flytekit.core.context_manager import FlyteContextManager from flytekit.configuration import Image, ImageConfig, SerializationSettings from flytekit.core.array_node import array_node -from flytekit.core.array_node_map_task import map_task +from flytekit.core.array_node_map_task import map from flytekit.models.core import identifier as identifier_models from flytekit.remote import FlyteLaunchPlan from flytekit.remote.interface import TypedInterface @@ -183,8 +183,8 @@ def grandparent_ex_wf() -> typing.List[typing.Optional[int]]: def test_map_task_wrapper(): - mapped_task = map_task(multiply)(val=[1, 3, 5], val1=[2, 4, 6], val2=[7, 8, 9]) + mapped_task = map(multiply)(val=[1, 3, 5], val1=[2, 4, 6], val2=[7, 8, 9]) assert mapped_task == [14, 96, 270] - mapped_lp = map_task(lp)(a=[1, 3, 5], b=[2, 4, 6], c=[7, 8, 9]) + mapped_lp = map(lp)(a=[1, 3, 5], b=[2, 4, 6], c=[7, 8, 9]) assert mapped_lp == [14, 96, 270] diff --git 
a/tests/flytekit/unit/core/test_array_node_map_task.py b/tests/flytekit/unit/core/test_array_node_map_task.py index b911678a9a..34cfcfec54 100644 --- a/tests/flytekit/unit/core/test_array_node_map_task.py +++ b/tests/flytekit/unit/core/test_array_node_map_task.py @@ -9,7 +9,7 @@ import pytest from flyteidl.core import workflow_pb2 as _core_workflow -from flytekit import dynamic, map_task, task, workflow, eager, PythonFunctionTask +from flytekit import dynamic, map, task, workflow, eager, PythonFunctionTask, Resources from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings from flytekit.core import context_manager from flytekit.core.array_node_map_task import ArrayNodeMapTask, ArrayNodeMapTaskResolver @@ -21,6 +21,7 @@ LiteralMap, LiteralOffloadedMetadata, ) +from flytekit.models.task import Resources as _resources_models from flytekit.tools.translator import get_serializable from flytekit.types.directory import FlyteDirectory @@ -62,7 +63,7 @@ def say_hello(name: str) -> str: @workflow def wf() -> List[str]: - return map_task(say_hello)(name=["abc", "def"]) + return map(say_hello)(name=["abc", "def"]) res = wf() assert res is not None @@ -79,8 +80,8 @@ def create_input_list() -> List[str]: @workflow def wf() -> List[str]: - xs = map_task(say_hello)(name=create_input_list()) - return map_task(say_hello)(name=xs) + xs = map(say_hello)(name=create_input_list()) + return map(say_hello)(name=xs) assert wf() == ["hello hello earth!!", "hello hello mars!!"] @@ -96,7 +97,7 @@ def say_hello(name: str) -> str: ctx.execution_state.with_params(mode=context_manager.ExecutionState.Mode.TASK_EXECUTION) ) ) as ctx: - t = map_task(say_hello) + t = map(say_hello) lm = TypeEngine.dict_to_literal_map(ctx, {"name": ["earth", "mars"]}, type_hints={"name": typing.List[str]}) res = t.dispatch_execute(ctx, lm) assert len(res.literals) == 1 @@ -108,7 +109,7 @@ def test_map_task_with_pickle(): def say_hello(name: typing.Any) -> str: return f"hello {name}!" - map_task(say_hello)(name=["abc", "def"]) + map(say_hello)(name=["abc", "def"]) def test_serialization(serialization_settings): @@ -116,7 +117,7 @@ def test_serialization(serialization_settings): def t1(a: int) -> int: return a + 1 - arraynode_maptask = map_task(t1, metadata=TaskMetadata(retries=2)) + arraynode_maptask = map(t1, metadata=TaskMetadata(retries=2)) task_spec = get_serializable(OrderedDict(), serialization_settings, arraynode_maptask) assert task_spec.template.metadata.retries.retries == 2 @@ -155,7 +156,7 @@ def test_fast_serialization(serialization_settings): def t1(a: int) -> int: return a + 1 - arraynode_maptask = map_task(t1, metadata=TaskMetadata(retries=2)) + arraynode_maptask = map(t1, metadata=TaskMetadata(retries=2)) task_spec = get_serializable(OrderedDict(), serialization_settings, arraynode_maptask) assert task_spec.template.container.args == [ @@ -225,8 +226,8 @@ def test_metadata_in_task_name(kwargs1, kwargs2, same): def say_hello(name: str) -> str: return f"hello {name}!" 
- t1 = map_task(say_hello, **kwargs1) - t2 = map_task(say_hello, **kwargs2) + t1 = map(say_hello, **kwargs1) + t2 = map(say_hello, **kwargs2) assert (t1.name == t2.name) is same @@ -236,7 +237,7 @@ def test_inputs_outputs_length(): def many_inputs(a: int, b: str, c: float) -> str: return f"{a} - {b} - {c}" - m = map_task(many_inputs) + m = map(many_inputs) assert m.python_interface.inputs == {"a": List[int], "b": List[str], "c": List[float]} assert ( m.name @@ -246,7 +247,7 @@ def many_inputs(a: int, b: str, c: float) -> str: assert str(r_m.python_interface) == str(m.python_interface) p1 = functools.partial(many_inputs, c=1.0) - m = map_task(p1) + m = map(p1) assert m.python_interface.inputs == {"a": List[int], "b": List[str], "c": float} assert ( m.name @@ -256,7 +257,7 @@ def many_inputs(a: int, b: str, c: float) -> str: assert str(r_m.python_interface) == str(m.python_interface) p2 = functools.partial(p1, b="hello") - m = map_task(p2) + m = map(p2) assert m.python_interface.inputs == {"a": List[int], "b": str, "c": float} assert ( m.name @@ -266,7 +267,7 @@ def many_inputs(a: int, b: str, c: float) -> str: assert str(r_m.python_interface) == str(m.python_interface) p3 = functools.partial(p2, a=1) - m = map_task(p3) + m = map(p3) assert m.python_interface.inputs == {"a": int, "b": str, "c": float} assert ( m.name @@ -283,7 +284,7 @@ def many_outputs(a: int) -> (int, str): return a, f"{a}" with pytest.raises(ValueError): - _ = map_task(many_outputs) + _ = map(many_outputs) def test_parameter_order(): @@ -303,9 +304,9 @@ def task3(c: str, a: int, b: float) -> str: param_b = [0.1, 0.2, 0.3] param_c = "c" - m1 = map_task(functools.partial(task1, c=param_c))(a=param_a, b=param_b) - m2 = map_task(functools.partial(task2, c=param_c))(a=param_a, b=param_b) - m3 = map_task(functools.partial(task3, c=param_c))(a=param_a, b=param_b) + m1 = map(functools.partial(task1, c=param_c))(a=param_a, b=param_b) + m2 = map(functools.partial(task2, c=param_c))(a=param_a, b=param_b) + m3 = map(functools.partial(task3, c=param_c))(a=param_a, b=param_b) assert m1 == m2 == m3 == ["1 - 0.1 - c", "2 - 0.2 - c", "3 - 0.3 - c"] @@ -315,7 +316,7 @@ def test_bounded_inputs_vars_order(serialization_settings): def task1(a: int, b: float, c: str) -> str: return f"{a} - {b} - {c}" - mt = map_task(functools.partial(task1, c=1.0, b="hello", a=1)) + mt = map(functools.partial(task1, c=1.0, b="hello", a=1)) mtr = ArrayNodeMapTaskResolver() args = mtr.loader_args(serialization_settings, mt) @@ -331,6 +332,7 @@ def task1(a: int, b: float, c: str) -> str: (0.5, False), ], ) + def test_raw_execute_with_min_success_ratio(min_success_ratio, should_raise_error): @task def some_task1(inputs: int) -> int: @@ -340,7 +342,7 @@ def some_task1(inputs: int) -> int: @workflow def my_wf1() -> typing.List[typing.Optional[int]]: - return map_task(some_task1, min_success_ratio=min_success_ratio)(inputs=[1, 2, 3, 4]) + return map(some_task1, min_success_ratio=min_success_ratio)(inputs=[1, 2, 3, 4]) if should_raise_error: with pytest.raises(ValueError): @@ -356,17 +358,18 @@ def my_mappable_task(a: int) -> typing.Optional[str]: @workflow def wf(x: typing.List[int]): - map_task(my_mappable_task)(a=x).with_overrides(container_image="random:image") + map(my_mappable_task)(a=x).with_overrides(container_image="random:image") assert wf.nodes[0]._container_image == "random:image" + def test_serialization_metadata(serialization_settings): @task(interruptible=True) def t1(a: int) -> int: return a + 1 - arraynode_maptask = map_task(t1, 
metadata=TaskMetadata(retries=2)) + arraynode_maptask = map(t1, metadata=TaskMetadata(retries=2)) # since we manually override task metadata, the underlying task metadata will not be copied. assert not arraynode_maptask.metadata.interruptible @@ -386,7 +389,7 @@ def test_serialization_metadata2(serialization_settings): def t1(a: int) -> typing.Optional[int]: return a + 1 - arraynode_maptask = map_task( + arraynode_maptask = map( t1, min_success_ratio=0.9, concurrency=10, @@ -398,7 +401,7 @@ def t1(a: int) -> typing.Optional[int]: def wf(x: typing.List[int]): return arraynode_maptask(a=x) - full_state_array_node_map_task = map_task(PythonFunctionTaskExtension(task_config={}, task_function=t1)) + full_state_array_node_map_task = map(PythonFunctionTaskExtension(task_config={}, task_function=t1)) @workflow def wf1(x: typing.List[int]): @@ -430,7 +433,7 @@ def test_serialization_extended_resources(serialization_settings): def t1(a: int) -> int: return a + 1 - arraynode_maptask = map_task(t1) + arraynode_maptask = map(t1) @workflow def wf(x: typing.List[int]): @@ -443,32 +446,12 @@ def wf(x: typing.List[int]): assert task_spec.template.extended_resources.gpu_accelerator.device == "test_gpu" -def test_serialization_extended_resources_shared_memory(serialization_settings): - @task( - shared_memory="2Gi" - ) - def t1(a: int) -> int: - return a + 1 - - arraynode_maptask = map_task(t1) - - @workflow - def wf(x: typing.List[int]): - return arraynode_maptask(a=x) - - od = OrderedDict() - get_serializable(od, serialization_settings, wf) - task_spec = od[arraynode_maptask] - - assert task_spec.template.extended_resources.shared_memory.size_limit == "2Gi" - - def test_supported_node_type(): @task def test_task(): ... - map_task(test_task) + map(test_task) def test_unsupported_node_types(): @@ -477,21 +460,21 @@ def test_dynamic(): ... with pytest.raises(ValueError): - map_task(test_dynamic) + map(test_dynamic) @eager def test_eager(): ... with pytest.raises(ValueError): - map_task(test_eager) + map(test_eager) @workflow def test_wf(): ... with pytest.raises(ValueError): - map_task(test_wf) + map(test_wf) def test_mis_match(): @@ -509,7 +492,7 @@ def consume_directories(dirs: List[FlyteDirectory]): for path_info, other_info in d.crawl(): print(path_info) - mt = map_task(generate_directory, min_success_ratio=0.1) + mt = map(generate_directory, tolerance=0.1) @workflow def wf(): @@ -551,7 +534,7 @@ def say_hello(name: str) -> str: for index, map_input_str in enumerate(list_strs): monkeypatch.setenv("BATCH_JOB_ARRAY_INDEX_VAR_NAME", "name") monkeypatch.setenv("name", str(index)) - t = map_task(say_hello) + t = map(say_hello) res = t.dispatch_execute(ctx, lm) assert len(res.literals) == 1 assert res.literals[f"o{0}"].scalar.primitive.string_value == f"hello {map_input_str}!" 
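
Reviewer note: the new `flytekit.map` surface that these tests exercise boils down to the sketch below. This is an illustration only, not part of the patch; the task and workflow names are hypothetical. Per `flytekit/core/array_node_map_task.py` above, a float `tolerance` is forwarded as `min_success_ratio`, an int as `min_successes`, and tolerated failures surface as `None` entries in the mapped output:

```python
import typing

from flytekit import map, task, workflow


@task
def parse(x: str) -> int:
    return int(x)  # raises for non-numeric input, so some map elements may fail


@workflow
def wf(xs: typing.List[str]) -> typing.List[typing.Optional[int]]:
    # A float tolerance in [0, 1] is forwarded as min_success_ratio=0.5; an int >= 1
    # would instead be forwarded as min_successes. Failed elements come back as None,
    # hence the Optional element type; dropping below the ratio raises at execution.
    return map(parse, tolerance=0.5)(x=xs)
```
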
diff --git a/tests/flytekit/unit/core/test_artifacts.py b/tests/flytekit/unit/core/test_artifacts.py index 34d19f50cb..666aa88e8f 100644 --- a/tests/flytekit/unit/core/test_artifacts.py +++ b/tests/flytekit/unit/core/test_artifacts.py @@ -12,7 +12,7 @@ from typing_extensions import Annotated, get_args from flytekit.configuration import Image, ImageConfig, SerializationSettings -from flytekit.core.array_node_map_task import map_task +from flytekit.core.array_node_map_task import map from flytekit.core.artifact import Artifact, Inputs, TimePartition from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager, OutputMetadataTracker from flytekit.core.interface import detect_artifact @@ -579,7 +579,7 @@ def t1(b_value: str) -> Annotated[pd.DataFrame, a1_b]: df = pd.DataFrame({"a": [1, 2, 3], "b": [b_value, b_value, b_value]}) return a1_b.create_from(df, b="dynamic!") - mt1 = map_task(t1) + mt1 = map(t1) entities = OrderedDict() mt1_s = get_serializable(entities, serialization_settings, mt1) o0 = mt1_s.template.interface.outputs["o0"] diff --git a/tests/flytekit/unit/core/test_interface.py b/tests/flytekit/unit/core/test_interface.py index 7020ba42dc..751fafad0c 100644 --- a/tests/flytekit/unit/core/test_interface.py +++ b/tests/flytekit/unit/core/test_interface.py @@ -6,7 +6,7 @@ import pytest from typing_extensions import Annotated, TypeVar # type: ignore -from flytekit import map_task, task +from flytekit import map, task from flytekit.core import context_manager from flytekit.core.docstring import Docstring from flytekit.core.interface import ( @@ -406,6 +406,6 @@ def test_map_task_interface(min_success_ratio, expected_type): def t() -> str: return "hello" - mt = map_task(t, min_success_ratio=min_success_ratio) + mt = map(t, min_success_ratio=min_success_ratio) assert mt.python_interface.outputs["o0"] == typing.List[expected_type] diff --git a/tests/flytekit/unit/core/test_launch_plan.py b/tests/flytekit/unit/core/test_launch_plan.py index 1f7023fb1c..653ebb1963 100644 --- a/tests/flytekit/unit/core/test_launch_plan.py +++ b/tests/flytekit/unit/core/test_launch_plan.py @@ -186,18 +186,18 @@ def wf(a: int, c: str) -> (int, str): assert raw_output_data_config_lp1 is raw_output_data_config_lp2 # Max parallelism - max_parallelism = 100 - max_parallelism_lp1 = launch_plan.LaunchPlan.get_or_create( + concurrency = 100 + concurrency_lp1 = launch_plan.LaunchPlan.get_or_create( workflow=wf, - name="get_or_create_max_parallelism", - max_parallelism=max_parallelism, + name="get_or_create_concurrency", + concurrency=concurrency, ) - max_parallelism_lp2 = launch_plan.LaunchPlan.get_or_create( + concurrency_lp2 = launch_plan.LaunchPlan.get_or_create( workflow=wf, - name="get_or_create_max_parallelism", - max_parallelism=max_parallelism, + name="get_or_create_concurrency", + concurrency=concurrency, ) - assert max_parallelism_lp1 is max_parallelism_lp2 + assert concurrency_lp1 is concurrency_lp2 # Labels parameters labels_model1 = Labels({"label": "foo"}) @@ -229,18 +229,18 @@ def wf(a: int, c: str) -> (int, str): assert raw_output_data_config_lp1 is raw_output_data_config_lp2 # Max parallelism - max_parallelism = 100 - max_parallelism_lp1 = launch_plan.LaunchPlan.get_or_create( + concurrency = 100 + concurrency_lp1 = launch_plan.LaunchPlan.get_or_create( workflow=wf, - name="get_or_create_max_parallelism", - max_parallelism=max_parallelism, + name="get_or_create_concurrency", + concurrency=concurrency, ) - max_parallelism_lp2 = launch_plan.LaunchPlan.get_or_create( + 
concurrency_lp2 = launch_plan.LaunchPlan.get_or_create( workflow=wf, - name="get_or_create_max_parallelism", - max_parallelism=max_parallelism, + name="get_or_create_concurrency", + concurrency=concurrency, ) - assert max_parallelism_lp1 is max_parallelism_lp2 + assert concurrency_lp1 is concurrency_lp2 # Labels parameters labels_model1 = Labels({"label": "foo"}) @@ -272,18 +272,18 @@ def wf(a: int, c: str) -> (int, str): assert raw_output_data_config_lp1 is raw_output_data_config_lp2 # Max parallelism - max_parallelism = 100 - max_parallelism_lp1 = launch_plan.LaunchPlan.get_or_create( + concurrency = 100 + concurrency_lp1 = launch_plan.LaunchPlan.get_or_create( workflow=wf, - name="get_or_create_max_parallelism", - max_parallelism=max_parallelism, + name="get_or_create_concurrency", + concurrency=concurrency, ) - max_parallelism_lp2 = launch_plan.LaunchPlan.get_or_create( + concurrency_lp2 = launch_plan.LaunchPlan.get_or_create( workflow=wf, - name="get_or_create_max_parallelism", - max_parallelism=max_parallelism, + name="get_or_create_concurrency", + concurrency=concurrency, ) - assert max_parallelism_lp1 is max_parallelism_lp2 + assert concurrency_lp1 is concurrency_lp2 # Overwrite cache overwrite_cache = True diff --git a/tests/flytekit/unit/core/test_node_creation.py b/tests/flytekit/unit/core/test_node_creation.py index 29fa758801..3f7d77f288 100644 --- a/tests/flytekit/unit/core/test_node_creation.py +++ b/tests/flytekit/unit/core/test_node_creation.py @@ -7,7 +7,7 @@ from kubernetes.client import V1PodSpec, V1Container, V1EnvVar import flytekit.configuration -from flytekit import Resources, map_task, PodTemplate +from flytekit import Resources, map, PodTemplate from flytekit.configuration import Image, ImageConfig from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.node_creation import create_node @@ -216,7 +216,7 @@ def t1(a: str) -> str: @workflow def my_wf(a: typing.List[str]) -> typing.List[str]: - mappy = map_task(t1) + mappy = map(t1) map_node = mappy(a=a).with_overrides(requests=Resources(cpu="1", mem="100", ephemeral_storage="500Mi")) return map_node @@ -245,7 +245,7 @@ def t1(a: str) -> str: @workflow def my_wf(a: typing.List[str]) -> typing.List[str]: - mappy = map_task(t1) + mappy = map(t1) map_node = mappy(a=a).with_overrides(limits=Resources(cpu="2", mem="200", ephemeral_storage="1Gi")) return map_node @@ -273,7 +273,7 @@ def t1(a: str) -> str: @workflow def my_wf(a: typing.List[str]) -> typing.List[str]: - mappy = map_task(t1) + mappy = map(t1) map_node = mappy(a=a).with_overrides( requests=Resources(cpu="1", mem="100", ephemeral_storage="500Mi"), limits=Resources(cpu="2", mem="200", ephemeral_storage="1Gi"), diff --git a/tests/flytekit/unit/core/test_partials.py b/tests/flytekit/unit/core/test_partials.py index c785a07ec7..c28e0b9ea0 100644 --- a/tests/flytekit/unit/core/test_partials.py +++ b/tests/flytekit/unit/core/test_partials.py @@ -8,7 +8,7 @@ import flytekit.configuration from flytekit.configuration import Image, ImageConfig from flytekit.core.array_node_map_task import ArrayNodeMapTaskResolver -from flytekit.core.array_node_map_task import map_task as array_node_map_task +from flytekit.core.array_node_map_task import map as array_node_map_task from flytekit.core.dynamic_workflow_task import dynamic from flytekit.core.legacy_map_task import MapTaskResolver from flytekit.core.legacy_map_task import map_task as legacy_map_task diff --git a/tests/flytekit/unit/core/test_type_hints.py b/tests/flytekit/unit/core/test_type_hints.py index 
ec977020b0..a1eabf65bf 100644 --- a/tests/flytekit/unit/core/test_type_hints.py +++ b/tests/flytekit/unit/core/test_type_hints.py @@ -19,7 +19,7 @@ import flytekit import flytekit.configuration -from flytekit import Secret, SQLTask, dynamic, kwtypes, map_task +from flytekit import Secret, SQLTask, dynamic, kwtypes, map from flytekit.configuration import FastSerializationSettings, Image, ImageConfig from flytekit.core import context_manager, launch_plan, promise from flytekit.core.condition import conditional @@ -598,7 +598,7 @@ def t2(x: typing.List[int]) -> int: @workflow def my_wf(a: typing.List[int]) -> int: - x = map_task(t1, metadata=TaskMetadata(retries=1))(a=a) + x = map(t1, metadata=TaskMetadata(retries=1))(a=a) return t2(x=x) x = my_wf(a=[5, 6]) diff --git a/tests/flytekit/unit/remote/test_remote.py b/tests/flytekit/unit/remote/test_remote.py index 9911cad02f..e690787a47 100644 --- a/tests/flytekit/unit/remote/test_remote.py +++ b/tests/flytekit/unit/remote/test_remote.py @@ -17,7 +17,7 @@ from mock import ANY, MagicMock, patch import flytekit.configuration -from flytekit import CronSchedule, ImageSpec, LaunchPlan, WorkflowFailurePolicy, task, workflow, reference_task, map_task, dynamic, eager +from flytekit import CronSchedule, ImageSpec, LaunchPlan, WorkflowFailurePolicy, task, workflow, reference_task, map, dynamic, eager from flytekit.configuration import Config, DefaultImages, Image, ImageConfig, SerializationSettings from flytekit.core.base_task import PythonTask from flytekit.core.context_manager import FlyteContextManager @@ -726,7 +726,7 @@ def t1(x: int, y: int) -> int: @workflow def w() -> int: - return map_task(partial(t1, y=2))(x=[1, 2, 3]) + return map(partial(t1, y=2))(x=[1, 2, 3]) _, target_dict = _get_pickled_target_dict(w) assert ( diff --git a/tests/flytekit/unit/types/directory/test_listdir.py b/tests/flytekit/unit/types/directory/test_listdir.py index 0987456907..2f49faaee5 100644 --- a/tests/flytekit/unit/types/directory/test_listdir.py +++ b/tests/flytekit/unit/types/directory/test_listdir.py @@ -1,7 +1,7 @@ import tempfile from pathlib import Path -from flytekit import FlyteDirectory, FlyteFile, map_task, task, workflow +from flytekit import FlyteDirectory, FlyteFile, map, task, workflow def test_listdir(): @task @@ -26,6 +26,6 @@ def list_dir(dir: FlyteDirectory) -> list[FlyteFile]: def wf() -> list[str]: tmpdir = setup() files = list_dir(dir=tmpdir) - return map_task(read_file)(file=files) + return map(read_file)(file=files) assert wf() == ["Hello, World!"]
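
Taken together, the public-facing rename can be exercised as in the sketch below (hypothetical workflow and launch plan names; behavior as implemented in `flytekit/core/launch_plan.py` and `flytekit/core/options.py` above). `concurrency` is the new spelling everywhere, and every `max_parallelism` entry point keeps working but emits a `DeprecationWarning`:

```python
import warnings

from flytekit import LaunchPlan, task, workflow
from flytekit.core.options import Options


@task
def step(a: int) -> int:
    return a + 1


@workflow
def wf(a: int) -> int:
    return step(a=a)


# New spelling: cap concurrently running task nodes at 10 for the whole workflow.
lp = LaunchPlan.get_or_create(workflow=wf, name="wf_lp", concurrency=10)
assert lp.concurrency == 10

# Options mirrors the rename; max_parallelism survives as a warning alias.
opts = Options(concurrency=10)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    assert opts.max_parallelism == 10  # reads through to .concurrency
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```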