Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

POC: Cache step while running #3350

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/zenml/orchestrators/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,17 @@ def get_cached_step_run(cache_key: str) -> Optional["StepRunResponse"]:
if cache_candidates:
return cache_candidates[0]
return None


def is_valid_cached_step_run(
    step_run: "StepRunResponse", cache_key: str
) -> bool:
    """Decide whether a step run can be reused as a cache hit.

    The checks run in order and stop at the first one that fails: the step
    run's cache key must equal `cache_key`, its status must be
    `ExecutionStatus.COMPLETED`, and its workspace must be the workspace
    that is currently active on the client.

    Args:
        step_run: The step run to check.
        cache_key: The cache key that the step run has to match.

    Returns:
        True if all of the conditions above hold, False otherwise.
    """
    if step_run.cache_key != cache_key:
        return False

    if step_run.status != ExecutionStatus.COMPLETED:
        return False

    # Only compare workspaces once the cheaper attribute checks have passed,
    # since this is the only check that needs the client.
    run_workspace_id = step_run.workspace.id
    return run_workspace_id == Client().active_workspace.id
22 changes: 22 additions & 0 deletions src/zenml/orchestrators/publish_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,28 @@ def publish_successful_step_run(
)


def publish_delayed_cached_step_run(
    step_run_id: "UUID", output_artifact_ids: Dict[str, List["UUID"]]
) -> "StepRunResponse":
    """Publishes a step run as cached.

    Updates the existing step run in the ZenML store so that its status is
    `ExecutionStatus.CACHED`, its end time is the value returned by
    `utc_now()`, and its outputs are the given artifact IDs.

    Args:
        step_run_id: The ID of the step run to update.
        output_artifact_ids: The output artifact IDs for the step run, keyed
            by output name.

    Returns:
        The updated step run.
    """
    return Client().zen_store.update_run_step(
        step_run_id=step_run_id,
        step_run_update=StepRunUpdate(
            # CACHED, not COMPLETED: the outputs are attached to this run
            # rather than produced by it.
            status=ExecutionStatus.CACHED,
            end_time=utc_now(),
            outputs=output_artifact_ids,
        ),
    )


def publish_failed_step_run(step_run_id: "UUID") -> "StepRunResponse":
"""Publishes a failed step run.

Expand Down
66 changes: 66 additions & 0 deletions src/zenml/orchestrators/step_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,14 @@
logger = get_logger(__name__)


from uuid import UUID


class CacheResult:
    """Marker return value for a step that wants to reuse a cached run.

    Holds the optional ID of the step run that should serve as the cache
    candidate.
    """

    def __init__(self, step_run_id: Optional[UUID] = None) -> None:
        """Initializes the cache result.

        Args:
            step_run_id: ID of the step run to use as the cache candidate,
                or None if no specific step run is requested.
        """
        self.step_run_id = step_run_id


class StepRunner:
"""Class to run steps."""

Expand Down Expand Up @@ -225,6 +233,64 @@ def run(
step_exception=None,
)

if isinstance(return_values, CacheResult):
from zenml.orchestrators import cache_utils

cache_key = step_run.cache_key

cached_step_run = None

if return_values.step_run_id:
try:
candidate = Client().get_run_step(
return_values.step_run_id
)
except KeyError:
raise StepInterfaceError(
"Unable to find cached candidate step"
f"`{return_values.step_run_id}`."
)
if cache_utils.is_valid_cached_step_run(
candidate, cache_key
):
cached_step_run = candidate
else:
raise StepInterfaceError(
"Invalid cache candidate "
f"`{return_values.step_run_id}` "
"specified. A valid cache candidate "
"must have the same cache key and has "
"been successfully executed."
)
elif (
cached_step_run
:= cache_utils.get_cached_step_run(
cache_key=cache_key
)
):
cached_step_run = cached_step_run
else:
raise StepInterfaceError(
"Unable to find cached step run for cache "
f"key `{cache_key}`."
)

# TODO: We need to be able to update this
# request.original_step_run_id = cached_step_run.id

output_artifact_ids = {
output_name: [
artifact.id for artifact in artifacts
]
for output_name, artifacts in cached_step_run.outputs.items()
}

publish_successful_step_run(
step_run_id=step_run_info.step_run_id,
output_artifact_ids=output_artifact_ids,
)
return

# Store and publish the output artifacts of the step function.
output_data = self._validate_outputs(
return_values, output_annotations
Expand Down