diff --git a/src/zenml/orchestrators/cache_utils.py b/src/zenml/orchestrators/cache_utils.py index c38850a168..30a07251d6 100644 --- a/src/zenml/orchestrators/cache_utils.py +++ b/src/zenml/orchestrators/cache_utils.py @@ -127,3 +127,17 @@ def get_cached_step_run(cache_key: str) -> Optional["StepRunResponse"]: if cache_candidates: return cache_candidates[0] return None + + +def is_valid_cached_step_run( + step_run: "StepRunResponse", cache_key: str +) -> bool: + """Check if a given step run is a valid cache candidate. + + A step run is a valid cache candidate if it has the same cache key and is successful. + """ + return ( + step_run.cache_key == cache_key + and step_run.status == ExecutionStatus.COMPLETED + and step_run.workspace.id == Client().active_workspace.id + ) diff --git a/src/zenml/orchestrators/publish_utils.py b/src/zenml/orchestrators/publish_utils.py index f47156c15d..b7734ea634 100644 --- a/src/zenml/orchestrators/publish_utils.py +++ b/src/zenml/orchestrators/publish_utils.py @@ -54,6 +54,28 @@ def publish_successful_step_run( ) +def publish_delayed_cached_step_run( + step_run_id: "UUID", output_artifact_ids: Dict[str, List["UUID"]] +) -> "StepRunResponse": + """Publishes a successful step run. + + Args: + step_run_id: The ID of the step run to update. + output_artifact_ids: The output artifact IDs for the step run. + + Returns: + The updated step run. + """ + return Client().zen_store.update_run_step( + step_run_id=step_run_id, + step_run_update=StepRunUpdate( + status=ExecutionStatus.CACHED, + end_time=utc_now(), + outputs=output_artifact_ids, + ), + ) + + def publish_failed_step_run(step_run_id: "UUID") -> "StepRunResponse": """Publishes a failed step run. diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py index f18f1c649a..3b1626348a 100644 --- a/src/zenml/orchestrators/step_runner.py +++ b/src/zenml/orchestrators/step_runner.py @@ -75,6 +75,14 @@ logger = get_logger(__name__) +from uuid import UUID + + +class CacheResult: + def __init__(self, step_run_id: Optional[UUID] = None) -> None: + self.step_run_id = step_run_id + + class StepRunner: """Class to run steps.""" @@ -225,6 +233,64 @@ def run( step_exception=None, ) + if isinstance(return_values, CacheResult): + from zenml.orchestrators import cache_utils + + cache_key = step_run.cache_key + + cached_step_run = None + + if return_values.step_run_id: + try: + candidate = Client().get_run_step( + return_values.step_run_id + ) + except KeyError: + raise StepInterfaceError( + "Unable to find cached candidate step" + f"`{return_values.step_run_id}`." + ) + if cache_utils.is_valid_cached_step_run( + candidate, cache_key + ): + cached_step_run = candidate + else: + raise StepInterfaceError( + "Invalid cache candidate " + f"`{return_values.step_run_id}` " + "specified. A valid cache candidate " + "must have the same cache key and has " + "been successfully executed." + ) + elif ( + cached_step_run + := cache_utils.get_cached_step_run( + cache_key=cache_key + ) + ): + cached_step_run = cached_step_run + else: + raise StepInterfaceError( + "Unable to find cached step run for cache " + f"key `{cache_key}`." + ) + + # TODO: We need to be able to update this + # request.original_step_run_id = cached_step_run.id + + output_artifact_ids = { + output_name: [ + artifact.id for artifact in artifacts + ] + for output_name, artifacts in cached_step_run.outputs.items() + } + + publish_successful_step_run( + step_run_id=step_run_info.step_run_id, + output_artifact_ids=output_artifact_ids, + ) + return + # Store and publish the output artifacts of the step function. output_data = self._validate_outputs( return_values, output_annotations