Skip to content

Commit

Permalink
Merge branch 'master' into pandas-in-pydantic
Browse files Browse the repository at this point in the history
  • Loading branch information
Future-Outlier committed Feb 17, 2025
2 parents c2579e4 + 66d4aed commit b5f2a6f
Show file tree
Hide file tree
Showing 66 changed files with 1,774 additions and 302 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,4 @@ coverage.xml

# Version file is auto-generated by setuptools_scm
flytekit/_version.py
testing
4 changes: 4 additions & 0 deletions flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@
:toctree: generated/
HashMethod
Cache
CachePolicy
VersionParameters
Artifacts
=========
Expand Down Expand Up @@ -223,6 +226,7 @@
from flytekit.core.artifact import Artifact
from flytekit.core.base_sql_task import SQLTask
from flytekit.core.base_task import SecurityContext, TaskMetadata, kwtypes
from flytekit.core.cache import Cache, CachePolicy, VersionParameters
from flytekit.core.checkpointer import Checkpoint
from flytekit.core.condition import conditional
from flytekit.core.container_task import ContainerTask
Expand Down
14 changes: 10 additions & 4 deletions flytekit/clients/auth_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,12 @@ def upgrade_channel_to_proxy_authenticated(cfg: PlatformConfig, in_channel: grpc
:param in_channel: grpc.Channel Precreated channel
:return: grpc.Channel. New composite channel
"""

def authenticator_factory():
return get_proxy_authenticator(cfg)

if cfg.proxy_command:
proxy_authenticator = get_proxy_authenticator(cfg)
return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(proxy_authenticator))
return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(authenticator_factory))
else:
return in_channel

Expand All @@ -137,8 +140,11 @@ def upgrade_channel_to_authenticated(cfg: PlatformConfig, in_channel: grpc.Chann
:param in_channel: grpc.Channel Precreated channel
:return: grpc.Channel. New composite channel
"""
authenticator = get_authenticator(cfg, RemoteClientConfigStore(in_channel))
return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(authenticator))

def authenticator_factory():
return get_authenticator(cfg, RemoteClientConfigStore(in_channel))

return grpc.intercept_channel(in_channel, AuthUnaryInterceptor(authenticator_factory))


def get_authenticated_channel(cfg: PlatformConfig) -> grpc.Channel:
Expand Down
17 changes: 12 additions & 5 deletions flytekit/clients/grpc_utils/auth_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,22 @@ class AuthUnaryInterceptor(grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamCli
is needed.
"""

def __init__(self, authenticator: Authenticator):
self._authenticator = authenticator
def __init__(self, get_authenticator: typing.Callable[[], Authenticator]):
self._get_authenticator = get_authenticator
self._authenticator = None

@property
def authenticator(self) -> Authenticator:
if self._authenticator is None:
self._authenticator = self._get_authenticator()
return self._authenticator

def _call_details_with_auth_metadata(self, client_call_details: grpc.ClientCallDetails) -> grpc.ClientCallDetails:
"""
Returns new ClientCallDetails with metadata added.
"""
metadata = client_call_details.metadata
auth_metadata = self._authenticator.fetch_grpc_call_auth_metadata()
auth_metadata = self.authenticator.fetch_grpc_call_auth_metadata()
if auth_metadata:
metadata = []
if client_call_details.metadata:
Expand Down Expand Up @@ -64,7 +71,7 @@ def intercept_unary_unary(
if not hasattr(e, "code"):
raise e
if e.code() == grpc.StatusCode.UNAUTHENTICATED or e.code() == grpc.StatusCode.UNKNOWN:
self._authenticator.refresh_credentials()
self.authenticator.refresh_credentials()
updated_call_details = self._call_details_with_auth_metadata(client_call_details)
return continuation(updated_call_details, request)
return fut
Expand All @@ -76,7 +83,7 @@ def intercept_unary_stream(self, continuation, client_call_details, request):
updated_call_details = self._call_details_with_auth_metadata(client_call_details)
c: grpc.Call = continuation(updated_call_details, request)
if c.code() == grpc.StatusCode.UNAUTHENTICATED:
self._authenticator.refresh_credentials()
self.authenticator.refresh_credentials()
updated_call_details = self._call_details_with_auth_metadata(client_call_details)
return continuation(updated_call_details, request)
return c
4 changes: 1 addition & 3 deletions flytekit/clis/sdk_in_container/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,4 @@ def init(template, project_name):
processed_contents = project_template_regex.sub(project_name_bytes, zip_contents)
dest_file.write(processed_contents)

click.echo(
f"Visit the {project_name} directory and follow the next steps in the Getting started guide (https://docs.flyte.org/en/latest/user_guide/getting_started_with_workflow_development/index.html) to proceed."
)
click.echo(f"Project initialized in directory {project_name}.")
18 changes: 17 additions & 1 deletion flytekit/clis/sdk_in_container/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,9 +1052,25 @@ def _create_command(
r = run_level_params.remote_instance()
flyte_ctx = r.context

final_inputs_with_defaults = loaded_entity.python_interface.inputs_with_defaults
if isinstance(loaded_entity, LaunchPlan):
# For LaunchPlans it is essential to handle fixed inputs and default inputs in a special way
# Fixed inputs are inputs that are always passed to the launch plan and cannot be overridden
# Default inputs are inputs that are optional and have a default value
# The final inputs to the launch plan are a combination of the fixed inputs and the default inputs
all_inputs = loaded_entity.python_interface.inputs_with_defaults
default_inputs = loaded_entity.saved_inputs
pmap = loaded_entity.parameters
final_inputs_with_defaults = {}
for name, _ in pmap.parameters.items():
_type, v = all_inputs[name]
if name in default_inputs:
v = default_inputs[name]
final_inputs_with_defaults[name] = _type, v

# Add options for each of the workflow inputs
params = []
for input_name, input_type_val in loaded_entity.python_interface.inputs_with_defaults.items():
for input_name, input_type_val in final_inputs_with_defaults.items():
literal_var = loaded_entity.interface.inputs.get(input_name)
python_type, default_val = input_type_val
required = type(None) not in get_args(python_type) and default_val is None
Expand Down
12 changes: 11 additions & 1 deletion flytekit/configuration/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
"""

import os
from typing import Optional, Protocol, runtime_checkable
from typing import List, Optional, Protocol, runtime_checkable

from click import Group
from importlib_metadata import entry_points

from flytekit import CachePolicy
from flytekit.configuration import Config, get_config_file
from flytekit.loggers import logger
from flytekit.remote import FlyteRemote
Expand Down Expand Up @@ -53,6 +54,10 @@ def get_default_image() -> Optional[str]:
def get_auth_success_html(endpoint: str) -> Optional[str]:
"""Get default success html for auth. Return None to use flytekit's default success html."""

@staticmethod
def get_default_cache_policies() -> List[CachePolicy]:
"""Get default cache policies for tasks."""


class FlytekitPlugin:
@staticmethod
Expand Down Expand Up @@ -103,6 +108,11 @@ def get_auth_success_html(endpoint: str) -> Optional[str]:
"""Get default success html. Return None to use flytekit's default success html."""
return None

@staticmethod
def get_default_cache_policies() -> List[CachePolicy]:
"""Get default cache policies for tasks."""
return []


def _get_plugin_from_entrypoint():
"""Get plugin from entrypoint."""
Expand Down
97 changes: 97 additions & 0 deletions flytekit/core/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import hashlib
from dataclasses import dataclass
from typing import Callable, Generic, List, Optional, Protocol, Tuple, Union, runtime_checkable

from typing_extensions import ParamSpec, TypeVar

from flytekit.core.pod_template import PodTemplate
from flytekit.image_spec.image_spec import ImageSpec

P = ParamSpec("P")
FuncOut = TypeVar("FuncOut")


@dataclass
class VersionParameters(Generic[P, FuncOut]):
"""
Parameters used for version hash generation.
param func: The function to generate a version for. This is an optional parameter and can be any callable
that matches the specified parameter and return types.
:type func: Optional[Callable[P, FuncOut]]
:param container_image: The container image to generate a version for. This can be a string representing the
image name or an ImageSpec object.
:type container_image: Optional[Union[str, ImageSpec]]
"""

func: Callable[P, FuncOut]
container_image: Optional[Union[str, ImageSpec]] = None
pod_template: Optional[PodTemplate] = None
pod_template_name: Optional[str] = None


@runtime_checkable
class CachePolicy(Protocol):
def get_version(self, salt: str, params: VersionParameters) -> str: ...


@dataclass
class Cache:
"""
Cache configuration for a task.
:param version: The version of the task. If not provided, the version will be generated based on the cache policies.
:type version: Optional[str]
:param serialize: Boolean that indicates if identical (ie. same inputs) instances of this task should be executed in
serial when caching is enabled. This means that given multiple concurrent executions over identical inputs,
only a single instance executes and the rest wait to reuse the cached results.
:type serialize: bool
:param ignored_inputs: A tuple of input names to ignore when generating the version hash.
:type ignored_inputs: Union[Tuple[str, ...], str]
:param salt: A salt used in the hash generation.
:type salt: str
:param policies: A list of cache policies to generate the version hash.
:type policies: Optional[Union[List[CachePolicy], CachePolicy]]
"""

version: Optional[str] = None
serialize: bool = False
ignored_inputs: Union[Tuple[str, ...], str] = ()
salt: str = ""
policies: Optional[Union[List[CachePolicy], CachePolicy]] = None

def __post_init__(self):
if isinstance(self.ignored_inputs, str):
self._ignored_inputs = (self.ignored_inputs,)
else:
self._ignored_inputs = self.ignored_inputs

# Normalize policies so that self._policies is always a list
if self.policies is None:
from flytekit.configuration.plugin import get_plugin

self._policies = get_plugin().get_default_cache_policies()
elif isinstance(self.policies, CachePolicy):
self._policies = [self.policies]

if self.version is None and not self._policies:
raise ValueError("If version is not defined then at least one cache policy needs to be set")

def get_ignored_inputs(self) -> Tuple[str, ...]:
return self._ignored_inputs

def get_version(self, params: VersionParameters) -> str:
if self.version is not None:
return self.version

task_hash = ""
for policy in self._policies:
try:
task_hash += policy.get_version(self.salt, params)
except Exception as e:
raise ValueError(
f"Failed to generate version for cache policy {policy}. Please consider setting the version in the Cache definition, e.g. Cache(version='v1.2.3')"
) from e

hash_obj = hashlib.sha256(task_hash.encode())
return hash_obj.hexdigest()
4 changes: 4 additions & 0 deletions flytekit/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,7 @@
CACHE_KEY_METADATA = "cache-key-metadata"

SERIALIZATION_FORMAT = "serialization-format"

# Shared memory mount name and path
SHARED_MEMORY_MOUNT_NAME = "flyte-shared-memory"
SHARED_MEMORY_MOUNT_PATH = "/dev/shm"
10 changes: 8 additions & 2 deletions flytekit/core/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import datetime
import typing
from typing import Any, Dict, List, Optional, Union
from typing import Literal as L

from flyteidl.core import tasks_pb2

from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import Resources, convert_resources_to_resource_model
from flytekit.core.resources import Resources, construct_extended_resources, convert_resources_to_resource_model
from flytekit.core.utils import _dnsify
from flytekit.extras.accelerators import BaseAccelerator
from flytekit.loggers import logger
Expand Down Expand Up @@ -193,6 +194,7 @@ def with_overrides(
cache: Optional[bool] = None,
cache_version: Optional[str] = None,
cache_serialize: Optional[bool] = None,
shared_memory: Optional[Union[L[True], str]] = None,
pod_template: Optional[PodTemplate] = None,
*args,
**kwargs,
Expand Down Expand Up @@ -240,7 +242,11 @@ def with_overrides(

if accelerator is not None:
assert_not_promise(accelerator, "accelerator")
self._extended_resources = tasks_pb2.ExtendedResources(gpu_accelerator=accelerator.to_flyte_idl())

if shared_memory is not None:
assert_not_promise(shared_memory, "shared_memory")

self._extended_resources = construct_extended_resources(accelerator=accelerator, shared_memory=shared_memory)

self._override_node_metadata(name, timeout, retries, interruptible, cache, cache_version, cache_serialize)

Expand Down
4 changes: 2 additions & 2 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,7 +947,7 @@ async def binding_data_from_python_std(
if transformer_override and hasattr(transformer_override, "extract_types_or_metadata"):
_, v_type = transformer_override.extract_types_or_metadata(t_value_type) # type: ignore
else:
_, v_type = DictTransformer.extract_types_or_metadata(t_value_type) # type: ignore
_, v_type = DictTransformer.extract_types(cast(typing.Type[dict], t_value_type))
m = _literals_models.BindingDataMap(
bindings={
k: await binding_data_from_python_std(
Expand Down Expand Up @@ -1482,7 +1482,7 @@ def flyte_entity_call_handler(
# call the blocking version of the async call handler
# This is a recursive call, the async handler also calls this function, so this conditional must match
# the one in the async function perfectly, otherwise you'll get infinite recursion.
loop_manager.run_sync(async_flyte_entity_call_handler, entity, **kwargs)
return loop_manager.run_sync(async_flyte_entity_call_handler, entity, **kwargs)

if ctx.execution_state and (
ctx.execution_state.mode == ExecutionState.Mode.TASK_EXECUTION
Expand Down
12 changes: 7 additions & 5 deletions flytekit/core/python_auto_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from abc import ABC
from dataclasses import dataclass
from typing import Callable, Dict, List, Optional, TypeVar, Union
from typing import Literal as L

from flyteidl.core import tasks_pb2

Expand All @@ -13,7 +14,7 @@
from flytekit.core.base_task import PythonTask, TaskMetadata, TaskResolverMixin
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.pod_template import PodTemplate
from flytekit.core.resources import Resources, ResourceSpec
from flytekit.core.resources import Resources, ResourceSpec, construct_extended_resources
from flytekit.core.tracked_abc import FlyteTrackedABC
from flytekit.core.tracker import TrackedInstance, extract_task_module
from flytekit.core.utils import _get_container_definition, _serialize_pod_spec, timeit
Expand Down Expand Up @@ -51,6 +52,7 @@ def __init__(
pod_template: Optional[PodTemplate] = None,
pod_template_name: Optional[str] = None,
accelerator: Optional[BaseAccelerator] = None,
shared_memory: Optional[Union[L[True], str]] = None,
**kwargs,
):
"""
Expand Down Expand Up @@ -78,6 +80,8 @@ def __init__(
:param pod_template: Custom PodTemplate for this task.
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
:param accelerator: The accelerator to use for this task.
:param shared_memory: If True, then shared memory will be attached to the container where the size is equal
to the allocated memory. If str, then the shared memory is set to that size.
"""
sec_ctx = None
if secret_requests:
Expand Down Expand Up @@ -128,6 +132,7 @@ def __init__(

self.pod_template = pod_template
self.accelerator = accelerator
self.shared_memory = shared_memory

@property
def task_resolver(self) -> TaskResolverMixin:
Expand Down Expand Up @@ -250,10 +255,7 @@ def get_extended_resources(self, settings: SerializationSettings) -> Optional[ta
"""
Returns the extended resources to allocate to the task on hosted Flyte.
"""
if self.accelerator is None:
return None

return tasks_pb2.ExtendedResources(gpu_accelerator=self.accelerator.to_flyte_idl())
return construct_extended_resources(accelerator=self.accelerator, shared_memory=self.shared_memory)


class DefaultTaskResolver(TrackedInstance, TaskResolverMixin):
Expand Down
Loading

0 comments on commit b5f2a6f

Please sign in to comment.