Skip to content

Commit

Permalink
Model links lazy evaluation in pipeline code (zenml-io#2205)
Browse files Browse the repository at this point in the history
* MV data lazy loading in pipelines

* add tests

* new template ref

* use `BaseModel`

* Auto-update of Starter template

* update to `v7` syntax

* update test signatures

* Auto-update of E2E template

* update test signatures

* wandb lint

* lint

* remove leftover

* Apply suggestions from code review

Co-authored-by: Alex Strick van Linschoten <[email protected]>

* renaming

* model lazy load in `model`

* metadata lazy load in `metadata`

* implement Michael's suggestions

---------

Co-authored-by: GitHub Actions <[email protected]>
Co-authored-by: Alex Strick van Linschoten <[email protected]>
  • Loading branch information
3 people authored and adtygan committed Mar 20, 2024
1 parent 097111c commit 2da0450
Show file tree
Hide file tree
Showing 23 changed files with 460 additions and 87 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/update-templates-to-examples.yml
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ jobs:
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
github.issues.createComment({
github.rest.issues.createComment({
issue_number: ${{ github.event.pull_request.number }},
owner: context.repo.owner,
repo: context.repo.repo,
Expand Down Expand Up @@ -165,7 +165,7 @@ jobs:
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
github.issues.createComment({
github.rest.issues.createComment({
issue_number: ${{ github.event.pull_request.number }},
owner: context.repo.owner,
repo: context.repo.repo,
Expand Down Expand Up @@ -237,7 +237,7 @@ jobs:
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |-
github.issues.createComment({
github.rest.issues.createComment({
issue_number: ${{ github.event.pull_request.number }},
owner: context.repo.owner,
repo: context.repo.repo,
Expand Down
4 changes: 2 additions & 2 deletions examples/e2e/steps/deployment/deployment_deploy.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ def deployment_deploy() -> (
# deploy predictor service
deployment_service = mlflow_model_registry_deployer_step.entrypoint(
registry_model_name=model_version.name,
registry_model_version=model_version.metadata[
registry_model_version=model_version.run_metadata[
"model_registry_version"
],
].value,
replace_existing=True,
)
else:
Expand Down
14 changes: 8 additions & 6 deletions examples/e2e/steps/promotion/promote_with_metric_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,17 +90,17 @@ def promote_with_metric_compare(
logger.info(f"Current model version was promoted to '{target_env}'.")

# Promote in Model Registry
latest_version_model_registry_number = latest_version.metadata[
latest_version_model_registry_number = latest_version.run_metadata[
"model_registry_version"
]
].value
if current_version_number is None:
current_version_model_registry_number = (
latest_version_model_registry_number
)
else:
current_version_model_registry_number = current_version.metadata[
"model_registry_version"
]
current_version_model_registry_number = (
current_version.run_metadata["model_registry_version"].value
)
promote_in_model_registry(
latest_version=latest_version_model_registry_number,
current_version=current_version_model_registry_number,
Expand All @@ -109,7 +109,9 @@ def promote_with_metric_compare(
)
promoted_version = latest_version_model_registry_number
else:
promoted_version = current_version.metadata["model_registry_version"]
promoted_version = current_version.run_metadata[
"model_registry_version"
].value

logger.info(
f"Current model version in `{target_env}` is `{promoted_version}` registered in Model Registry"
Expand Down
1 change: 1 addition & 0 deletions src/zenml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
from zenml.artifacts.artifact_config import ArtifactConfig
from zenml.artifacts.external_artifact import ExternalArtifact
from zenml.model.model_version import ModelVersion
from zenml.model.utils import log_model_version_metadata
from zenml.new.pipelines.pipeline_context import get_pipeline_context
from zenml.new.pipelines.pipeline_decorator import pipeline
from zenml.new.steps.step_decorator import step
Expand Down
2 changes: 1 addition & 1 deletion src/zenml/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _model_version_to_print(
"number": model_version.number,
"description": model_version.description,
"stage": model_version.stage,
"metadata": model_version.to_model_version().metadata,
"metadata": model_version.to_model_version().run_metadata,
"tags": [t.name for t in model_version.tags],
"data_artifacts_count": len(model_version.data_artifact_ids),
"model_artifacts_count": len(model_version.model_artifact_ids),
Expand Down
2 changes: 2 additions & 0 deletions src/zenml/config/step_configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from zenml.config.source import Source, convert_source_validator
from zenml.config.strict_base_model import StrictBaseModel
from zenml.logger import get_logger
from zenml.model.lazy_load import ModelVersionDataLazyLoader
from zenml.model.model_version import ModelVersion
from zenml.utils import deprecation_utils

Expand Down Expand Up @@ -152,6 +153,7 @@ class PartialStepConfiguration(StepConfigurationUpdate):
name: str
caching_parameters: Mapping[str, Any] = {}
external_input_artifacts: Mapping[str, ExternalArtifactConfiguration] = {}
model_artifacts_or_metadata: Mapping[str, ModelVersionDataLazyLoader] = {}
outputs: Mapping[str, PartialArtifactConfiguration] = {}

# Override the deprecation validator as we do not want to deprecate the
Expand Down
64 changes: 64 additions & 0 deletions src/zenml/metadata/lazy_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Run Metadata Lazy Loader definition."""

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
from zenml.model.model_version import ModelVersion
from zenml.models import RunMetadataResponse


class RunMetadataLazyGetter:
    """Run Metadata Lazy Getter helper class.

    Indexing an instance with a metadata key does not look anything up;
    it returns a lazy run metadata wrapper describing what should be
    loaded later. It is intended to be handed back to the user when
    metadata is accessed inside a pipeline design time context.
    """

    def __init__(
        self,
        _lazy_load_model_version: "ModelVersion",
        _lazy_load_artifact_name: Optional[str],
        _lazy_load_artifact_version: Optional[str],
    ) -> None:
        """Initialize a RunMetadataLazyGetter.

        The arguments are only stored here; they are forwarded unchanged
        to every wrapper created by `__getitem__`.

        Args:
            _lazy_load_model_version: The model version.
            _lazy_load_artifact_name: The artifact name.
            _lazy_load_artifact_version: The artifact version.
        """
        self._lazy_load_model_version = _lazy_load_model_version
        self._lazy_load_artifact_name = _lazy_load_artifact_name
        self._lazy_load_artifact_version = _lazy_load_artifact_version

    def __getitem__(self, key: str) -> "RunMetadataResponse":
        """Get the metadata lazy loader wrapper for the given key.

        Args:
            key: The metadata key.

        Returns:
            The metadata lazy loader wrapper for the given key.
        """
        # Imported at call time rather than at module level, matching the
        # original block (presumably to avoid a circular import between
        # `zenml.metadata` and `zenml.models` - TODO confirm).
        from zenml.models.v2.core.run_metadata import LazyRunMetadataResponse

        return LazyRunMetadataResponse(
            _lazy_load_model_version=self._lazy_load_model_version,
            _lazy_load_artifact_name=self._lazy_load_artifact_name,
            _lazy_load_artifact_version=self._lazy_load_artifact_version,
            _lazy_load_metadata_name=key,
        )
5 changes: 4 additions & 1 deletion src/zenml/metadata/metadata_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@
# permissions and limitations under the License.
"""Custom types that can be used as metadata of ZenML artifacts."""

from typing import Any, Dict, List, Set, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, Set, Tuple, Union

from zenml.utils.enum_utils import StrEnum

if TYPE_CHECKING:
pass


class Uri(str):
"""Special string class to indicate a URI."""
Expand Down
34 changes: 34 additions & 0 deletions src/zenml/model/lazy_load.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at:
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""Model Version Data Lazy Loader definition."""

from typing import Optional

from pydantic import BaseModel

from zenml.model.model_version import ModelVersion


class ModelVersionDataLazyLoader(BaseModel):
    """Model Version Data Lazy Loader helper class.

    It helps internal code fetch the proper artifact, model version
    metadata or artifact metadata from the model version at step
    runtime.
    """

    # The model version the data is fetched from.
    model_version: ModelVersion
    # Optional selectors; each defaults to None when not provided.
    artifact_name: Optional[str] = None
    artifact_version: Optional[str] = None
    metadata_name: Optional[str] = None
Loading

0 comments on commit 2da0450

Please sign in to comment.