diff --git a/.github/workflows/update-templates-to-examples.yml b/.github/workflows/update-templates-to-examples.yml index 2c96c17925b..7e4f3bb2133 100644 --- a/.github/workflows/update-templates-to-examples.yml +++ b/.github/workflows/update-templates-to-examples.yml @@ -95,7 +95,7 @@ jobs: with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | - github.issues.createComment({ + github.rest.issues.createComment({ issue_number: ${{ github.event.pull_request.number }}, owner: context.repo.owner, repo: context.repo.repo, @@ -165,7 +165,7 @@ jobs: with: github-token: ${{ secrets.GITHUB_TOKEN }} script: | - github.issues.createComment({ + github.rest.issues.createComment({ issue_number: ${{ github.event.pull_request.number }}, owner: context.repo.owner, repo: context.repo.repo, @@ -237,7 +237,7 @@ jobs: with: github-token: ${{ secrets.GITHUB_TOKEN }} script: |- - github.issues.createComment({ + github.rest.issues.createComment({ issue_number: ${{ github.event.pull_request.number }}, owner: context.repo.owner, repo: context.repo.repo, diff --git a/examples/e2e/steps/deployment/deployment_deploy.py b/examples/e2e/steps/deployment/deployment_deploy.py index b40d5119e78..2ca8d11b474 100644 --- a/examples/e2e/steps/deployment/deployment_deploy.py +++ b/examples/e2e/steps/deployment/deployment_deploy.py @@ -65,9 +65,9 @@ def deployment_deploy() -> ( # deploy predictor service deployment_service = mlflow_model_registry_deployer_step.entrypoint( registry_model_name=model_version.name, - registry_model_version=model_version.metadata[ + registry_model_version=model_version.run_metadata[ "model_registry_version" - ], + ].value, replace_existing=True, ) else: diff --git a/examples/e2e/steps/promotion/promote_with_metric_compare.py b/examples/e2e/steps/promotion/promote_with_metric_compare.py index 9cef476a880..7d0579b33cd 100644 --- a/examples/e2e/steps/promotion/promote_with_metric_compare.py +++ b/examples/e2e/steps/promotion/promote_with_metric_compare.py @@ -90,17 +90,17 @@ def promote_with_metric_compare( logger.info(f"Current model version was promoted to '{target_env}'.") # Promote in Model Registry - latest_version_model_registry_number = latest_version.metadata[ + latest_version_model_registry_number = latest_version.run_metadata[ "model_registry_version" - ] + ].value if current_version_number is None: current_version_model_registry_number = ( latest_version_model_registry_number ) else: - current_version_model_registry_number = current_version.metadata[ - "model_registry_version" - ] + current_version_model_registry_number = ( + current_version.run_metadata["model_registry_version"].value + ) promote_in_model_registry( latest_version=latest_version_model_registry_number, current_version=current_version_model_registry_number, @@ -109,7 +109,9 @@ def promote_with_metric_compare( ) promoted_version = latest_version_model_registry_number else: - promoted_version = current_version.metadata["model_registry_version"] + promoted_version = current_version.run_metadata[ + "model_registry_version" + ].value logger.info( f"Current model version in `{target_env}` is `{promoted_version}` registered in Model Registry" diff --git a/src/zenml/__init__.py b/src/zenml/__init__.py index 3cdd2d95765..aab2d7e2546 100644 --- a/src/zenml/__init__.py +++ b/src/zenml/__init__.py @@ -46,6 +46,7 @@ from zenml.artifacts.artifact_config import ArtifactConfig from zenml.artifacts.external_artifact import ExternalArtifact from zenml.model.model_version import ModelVersion +from zenml.model.utils import log_model_version_metadata from zenml.new.pipelines.pipeline_context import get_pipeline_context from zenml.new.pipelines.pipeline_decorator import pipeline from zenml.new.steps.step_decorator import step diff --git a/src/zenml/cli/model.py b/src/zenml/cli/model.py index da0f2410f29..b8352451e6b 100644 --- a/src/zenml/cli/model.py +++ b/src/zenml/cli/model.py @@ -60,7 +60,7 @@ def _model_version_to_print( "number": model_version.number, "description": model_version.description, "stage": model_version.stage, - "metadata": model_version.to_model_version().metadata, + "metadata": model_version.to_model_version().run_metadata, "tags": [t.name for t in model_version.tags], "data_artifacts_count": len(model_version.data_artifact_ids), "model_artifacts_count": len(model_version.model_artifact_ids), diff --git a/src/zenml/config/step_configurations.py b/src/zenml/config/step_configurations.py index 6ac2dc4e688..7a4f75b2308 100644 --- a/src/zenml/config/step_configurations.py +++ b/src/zenml/config/step_configurations.py @@ -34,6 +34,7 @@ from zenml.config.source import Source, convert_source_validator from zenml.config.strict_base_model import StrictBaseModel from zenml.logger import get_logger +from zenml.model.lazy_load import ModelVersionDataLazyLoader from zenml.model.model_version import ModelVersion from zenml.utils import deprecation_utils @@ -152,6 +153,7 @@ class PartialStepConfiguration(StepConfigurationUpdate): name: str caching_parameters: Mapping[str, Any] = {} external_input_artifacts: Mapping[str, ExternalArtifactConfiguration] = {} + model_artifacts_or_metadata: Mapping[str, ModelVersionDataLazyLoader] = {} outputs: Mapping[str, PartialArtifactConfiguration] = {} # Override the deprecation validator as we do not want to deprecate the diff --git a/src/zenml/metadata/lazy_load.py b/src/zenml/metadata/lazy_load.py new file mode 100644 index 00000000000..4b180d918b9 --- /dev/null +++ b/src/zenml/metadata/lazy_load.py @@ -0,0 +1,64 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Run Metadata Lazy Loader definition.""" + +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from zenml.model.model_version import ModelVersion + from zenml.models import RunMetadataResponse + + +class RunMetadataLazyGetter: + """Run Metadata Lazy Getter helper class. + + It serves the purpose to feed back to the user the metadata + lazy loader wrapper for any given key, if called inside a pipeline + design time context. + """ + + def __init__( + self, + _lazy_load_model_version: "ModelVersion", + _lazy_load_artifact_name: Optional[str], + _lazy_load_artifact_version: Optional[str], + ): + """Initialize a RunMetadataLazyGetter. + + Args: + _lazy_load_model_version: The model version. + _lazy_load_artifact_name: The artifact name. + _lazy_load_artifact_version: The artifact version. + """ + self._lazy_load_model_version = _lazy_load_model_version + self._lazy_load_artifact_name = _lazy_load_artifact_name + self._lazy_load_artifact_version = _lazy_load_artifact_version + + def __getitem__(self, key: str) -> "RunMetadataResponse": + """Get the metadata for the given key. + + Args: + key: The metadata key. + + Returns: + The metadata lazy loader wrapper for the given key. + """ + from zenml.models.v2.core.run_metadata import LazyRunMetadataResponse + + return LazyRunMetadataResponse( + _lazy_load_model_version=self._lazy_load_model_version, + _lazy_load_artifact_name=self._lazy_load_artifact_name, + _lazy_load_artifact_version=self._lazy_load_artifact_version, + _lazy_load_metadata_name=key, + ) diff --git a/src/zenml/metadata/metadata_types.py b/src/zenml/metadata/metadata_types.py index 0cf072c1d76..1976c56c11c 100644 --- a/src/zenml/metadata/metadata_types.py +++ b/src/zenml/metadata/metadata_types.py @@ -13,10 +13,13 @@ # permissions and limitations under the License. """Custom types that can be used as metadata of ZenML artifacts.""" -from typing import Any, Dict, List, Set, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Set, Tuple, Union from zenml.utils.enum_utils import StrEnum +if TYPE_CHECKING: + pass + class Uri(str): """Special string class to indicate a URI.""" diff --git a/src/zenml/model/lazy_load.py b/src/zenml/model/lazy_load.py new file mode 100644 index 00000000000..2df22dc5b4a --- /dev/null +++ b/src/zenml/model/lazy_load.py @@ -0,0 +1,34 @@ +# Copyright (c) ZenML GmbH 2024. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Model Version Data Lazy Loader definition.""" + +from typing import Optional + +from pydantic import BaseModel + +from zenml.model.model_version import ModelVersion + + +class ModelVersionDataLazyLoader(BaseModel): + """Model Version Data Lazy Loader helper class. + + It helps the inner codes to fetch proper artifact, + model version metadata or artifact metadata from the + model version during runtime time of the step. + """ + + model_version: ModelVersion + artifact_name: Optional[str] = None + artifact_version: Optional[str] = None + metadata_name: Optional[str] = None diff --git a/src/zenml/model/model_version.py b/src/zenml/model/model_version.py index b1afb31fa84..4c40e9319f5 100644 --- a/src/zenml/model/model_version.py +++ b/src/zenml/model/model_version.py @@ -30,13 +30,13 @@ from zenml.logger import get_logger if TYPE_CHECKING: - from zenml import ExternalArtifact from zenml.metadata.metadata_types import MetadataType from zenml.models import ( ArtifactVersionResponse, ModelResponse, ModelVersionResponse, PipelineRunResponse, + RunMetadataResponse, ) logger = get_logger(__name__) @@ -181,26 +181,11 @@ def load_artifact(self, name: str, version: Optional[str] = None) -> Any: return load_artifact(artifact.id, str(artifact.version)) - def _try_get_as_external_artifact( - self, - name: str, - version: Optional[str] = None, - ) -> Optional["ExternalArtifact"]: - from zenml import ExternalArtifact, get_pipeline_context - - try: - get_pipeline_context() - except RuntimeError: - return None - - ea = ExternalArtifact(name=name, version=version, model_version=self) - return ea - def get_artifact( self, name: str, version: Optional[str] = None, - ) -> Optional[Union["ArtifactVersionResponse", "ExternalArtifact"]]: + ) -> Optional["ArtifactVersionResponse"]: """Get the artifact linked to this model version. Args: @@ -208,11 +193,11 @@ def get_artifact( version: The version of the artifact to retrieve (None for latest/non-versioned) Returns: - Inside pipeline context: ExternalArtifact object as a lazy loader - Outside of pipeline context: Specific version of the artifact or None + Specific version of the artifact or placeholder in the design time of the pipeline. """ - if response := self._try_get_as_external_artifact(name, version): - return response + if lazy := self._lazy_artifact_get(name, version): + return lazy + return self._get_or_create_model_version().get_artifact( name=name, version=version, @@ -222,7 +207,7 @@ def get_model_artifact( self, name: str, version: Optional[str] = None, - ) -> Optional[Union["ArtifactVersionResponse", "ExternalArtifact"]]: + ) -> Optional["ArtifactVersionResponse"]: """Get the model artifact linked to this model version. Args: @@ -230,11 +215,11 @@ def get_model_artifact( version: The version of the model artifact to retrieve (None for latest/non-versioned) Returns: - Inside pipeline context: ExternalArtifact object as a lazy loader - Outside of pipeline context: Specific version of the model artifact or None + Specific version of the model artifact or placeholder in the design time of the pipeline. """ - if response := self._try_get_as_external_artifact(name, version): - return response + if lazy := self._lazy_artifact_get(name, version): + return lazy + return self._get_or_create_model_version().get_model_artifact( name=name, version=version, @@ -244,7 +229,7 @@ def get_data_artifact( self, name: str, version: Optional[str] = None, - ) -> Optional[Union["ArtifactVersionResponse", "ExternalArtifact"]]: + ) -> Optional["ArtifactVersionResponse"]: """Get the data artifact linked to this model version. Args: @@ -252,11 +237,11 @@ def get_data_artifact( version: The version of the data artifact to retrieve (None for latest/non-versioned) Returns: - Inside pipeline context: ExternalArtifact object as a lazy loader - Outside of pipeline context: Specific version of the data artifact or None + Specific version of the data artifact or placeholder in the design time of the pipeline. """ - if response := self._try_get_as_external_artifact(name, version): - return response + if lazy := self._lazy_artifact_get(name, version): + return lazy + return self._get_or_create_model_version().get_data_artifact( name=name, version=version, @@ -266,7 +251,7 @@ def get_deployment_artifact( self, name: str, version: Optional[str] = None, - ) -> Optional[Union["ArtifactVersionResponse", "ExternalArtifact"]]: + ) -> Optional["ArtifactVersionResponse"]: """Get the deployment artifact linked to this model version. Args: @@ -274,11 +259,11 @@ def get_deployment_artifact( version: The version of the deployment artifact to retrieve (None for latest/non-versioned) Returns: - Inside pipeline context: ExternalArtifact object as a lazy loader - Outside of pipeline context: Specific version of the deployment artifact or None + Specific version of the deployment artifact or placeholder in the design time of the pipeline. """ - if response := self._try_get_as_external_artifact(name, version): - return response + if lazy := self._lazy_artifact_get(name, version): + return lazy + return self._get_or_create_model_version().get_deployment_artifact( name=name, version=version, @@ -327,24 +312,37 @@ def log_metadata( ) @property - def metadata(self) -> Dict[str, "MetadataType"]: - """Get model version metadata. + def run_metadata(self) -> Dict[str, "RunMetadataResponse"]: + """Get model version run metadata. Returns: - The model version metadata. + The model version run metadata. Raises: - RuntimeError: If the model version metadata cannot be fetched. + RuntimeError: If the model version run metadata cannot be fetched. """ + from zenml.metadata.lazy_load import RunMetadataLazyGetter + from zenml.new.pipelines.pipeline_context import ( + get_pipeline_context, + ) + + try: + get_pipeline_context() + # avoid exposing too much of internal details by keeping the return type + return RunMetadataLazyGetter( # type: ignore[return-value] + self, + None, + None, + ) + except RuntimeError: + pass + response = self._get_or_create_model_version(hydrate=True) if response.run_metadata is None: raise RuntimeError( "Failed to fetch metadata of this model version." ) - return { - name: response.value - for name, response in response.run_metadata.items() - } + return response.run_metadata def delete_artifact( self, @@ -418,6 +416,30 @@ class Config: smart_union = True + def _lazy_artifact_get( + self, + name: str, + version: Optional[str] = None, + ) -> Optional["ArtifactVersionResponse"]: + from zenml import get_pipeline_context + from zenml.models.v2.core.artifact_version import ( + LazyArtifactVersionResponse, + ) + + try: + get_pipeline_context() + return LazyArtifactVersionResponse( + _lazy_load_name=name, + _lazy_load_version=version, + _lazy_load_model_version=ModelVersion( + name=self.name, version=self.version or self.number + ), + ) + except RuntimeError: + pass + + return None + def __eq__(self, other: object) -> bool: """Check two ModelVersions for equality. diff --git a/src/zenml/models/v2/core/artifact_version.py b/src/zenml/models/v2/core/artifact_version.py index 5b6820b18d4..e39dc2fa530 100644 --- a/src/zenml/models/v2/core/artifact_version.py +++ b/src/zenml/models/v2/core/artifact_version.py @@ -44,6 +44,7 @@ if TYPE_CHECKING: from sqlalchemy.sql.elements import BinaryExpression, BooleanClauseList + from zenml.model.model_version import ModelVersion from zenml.models.v2.core.artifact_visualization import ( ArtifactVisualizationRequest, ArtifactVisualizationResponse, @@ -461,3 +462,52 @@ def get_custom_filters( custom_filters.append(custom_name_filter) return custom_filters + + +# -------------------- Lazy Loader -------------------- + + +class LazyArtifactVersionResponse(ArtifactVersionResponse): + """Lazy artifact version response. + + Used if the artifact version is accessed from the model in + a pipeline context available only during pipeline compilation. + """ + + id: Optional[UUID] = None # type: ignore[assignment] + _lazy_load_name: Optional[str] = None + _lazy_load_version: Optional[str] = None + _lazy_load_model_version: "ModelVersion" + + def get_body(self) -> None: # type: ignore[override] + """Protects from misuse of the lazy loader. + + Raises: + RuntimeError: always + """ + raise RuntimeError("Cannot access artifact body before pipeline runs.") + + def get_metadata(self) -> None: # type: ignore[override] + """Protects from misuse of the lazy loader. + + Raises: + RuntimeError: always + """ + raise RuntimeError( + "Cannot access artifact metadata before pipeline runs." + ) + + @property + def run_metadata(self) -> Dict[str, "RunMetadataResponse"]: + """The `metadata` property in lazy loading mode. + + Returns: + getter of lazy responses for internal use. + """ + from zenml.metadata.lazy_load import RunMetadataLazyGetter + + return RunMetadataLazyGetter( # type: ignore[return-value] + self._lazy_load_model_version, + self._lazy_load_name, + self._lazy_load_version, + ) diff --git a/src/zenml/models/v2/core/run_metadata.py b/src/zenml/models/v2/core/run_metadata.py index dc8ac64ec74..e8b3045e554 100644 --- a/src/zenml/models/v2/core/run_metadata.py +++ b/src/zenml/models/v2/core/run_metadata.py @@ -13,7 +13,7 @@ # permissions and limitations under the License. """Models representing run metadata.""" -from typing import Dict, Optional, Union +from typing import TYPE_CHECKING, Dict, Optional, Union from uuid import UUID from pydantic import Field @@ -29,6 +29,9 @@ WorkspaceScopedResponseMetadata, ) +if TYPE_CHECKING: + from zenml.model.model_version import ModelVersion + # ------------------ Request Model ------------------ @@ -184,3 +187,40 @@ class RunMetadataFilter(WorkspaceScopedFilter): stack_component_id: Optional[Union[str, UUID]] = None key: Optional[str] = None type: Optional[Union[str, MetadataTypeEnum]] = None + + +# -------------------- Lazy Loader -------------------- + + +class LazyRunMetadataResponse(RunMetadataResponse): + """Lazy run metadata response. + + Used if the run metadata is accessed from the model in + a pipeline context available only during pipeline compilation. + """ + + id: Optional[UUID] = None # type: ignore[assignment] + _lazy_load_artifact_name: Optional[str] = None + _lazy_load_artifact_version: Optional[str] = None + _lazy_load_metadata_name: Optional[str] = None + _lazy_load_model_version: "ModelVersion" + + def get_body(self) -> None: # type: ignore[override] + """Protects from misuse of the lazy loader. + + Raises: + RuntimeError: always + """ + raise RuntimeError( + "Cannot access run metadata body before pipeline runs." + ) + + def get_metadata(self) -> None: # type: ignore[override] + """Protects from misuse of the lazy loader. + + Raises: + RuntimeError: always + """ + raise RuntimeError( + "Cannot access run metadata metadata before pipeline runs." + ) diff --git a/src/zenml/new/pipelines/pipeline.py b/src/zenml/new/pipelines/pipeline.py index b5006ad1ae8..d39a4640476 100644 --- a/src/zenml/new/pipelines/pipeline.py +++ b/src/zenml/new/pipelines/pipeline.py @@ -92,6 +92,7 @@ from zenml.artifacts.external_artifact import ExternalArtifact from zenml.config.base_settings import SettingsOrDict from zenml.config.source import Source + from zenml.model.lazy_load import ModelVersionDataLazyLoader from zenml.model.model_version import ModelVersion StepConfigurationUpdateOrDict = Union[ @@ -1259,6 +1260,7 @@ def add_step_invocation( step: "BaseStep", input_artifacts: Dict[str, StepArtifact], external_artifacts: Dict[str, "ExternalArtifact"], + model_artifacts_or_metadata: Dict[str, "ModelVersionDataLazyLoader"], parameters: Dict[str, Any], default_parameters: Dict[str, Any], upstream_steps: Set[str], @@ -1271,6 +1273,8 @@ def add_step_invocation( step: The step for which to add an invocation. input_artifacts: The input artifacts for the invocation. external_artifacts: The external artifacts for the invocation. + model_artifacts_or_metadata: The model artifacts or metadata for + the invocation. parameters: The parameters for the invocation. default_parameters: The default parameters for the invocation. upstream_steps: The upstream steps for the invocation. @@ -1307,6 +1311,7 @@ def add_step_invocation( step=step, input_artifacts=input_artifacts, external_artifacts=external_artifacts, + model_artifacts_or_metadata=model_artifacts_or_metadata, parameters=parameters, default_parameters=default_parameters, upstream_steps=upstream_steps, diff --git a/src/zenml/new/pipelines/pipeline_context.py b/src/zenml/new/pipelines/pipeline_context.py index 1a977bcc646..e8ac864e68a 100644 --- a/src/zenml/new/pipelines/pipeline_context.py +++ b/src/zenml/new/pipelines/pipeline_context.py @@ -20,10 +20,10 @@ def get_pipeline_context() -> "PipelineContext": - """Get the context of the currently composing pipeline. + """Get the context of the current pipeline. Returns: - The context of the currently composing pipeline. + The context of the current pipeline. Raises: RuntimeError: If no active pipeline is found. @@ -50,7 +50,7 @@ def get_pipeline_context() -> "PipelineContext": class PipelineContext: - """Provides pipeline configuration during it's composition. + """Provides pipeline configuration context. Usage example: @@ -90,7 +90,7 @@ def my_pipeline(): """ def __init__(self, pipeline_configuration: "PipelineConfiguration"): - """Initialize the context of the currently composing pipeline. + """Initialize the context of the current pipeline. Args: pipeline_configuration: The configuration of the pipeline derived diff --git a/src/zenml/orchestrators/input_utils.py b/src/zenml/orchestrators/input_utils.py index f7f4bfff22a..8e5a097f7ad 100644 --- a/src/zenml/orchestrators/input_utils.py +++ b/src/zenml/orchestrators/input_utils.py @@ -39,6 +39,8 @@ def resolve_step_inputs( Raises: InputResolutionError: If input resolving failed due to a missing step or output. + ValueError: If object from model version passed into a step cannot be + resolved in runtime due to missing object. Returns: The IDs of the input artifact versions and the IDs of parent steps of @@ -81,6 +83,49 @@ def resolve_step_inputs( artifact_version_id ) + for name, config_ in step.config.model_artifacts_or_metadata.items(): + issue_found = False + try: + if config_.metadata_name is None and config_.artifact_name: + if artifact_ := config_.model_version.get_artifact( + config_.artifact_name, config_.artifact_version + ): + input_artifacts[name] = artifact_ + else: + issue_found = True + elif config_.artifact_name is None and config_.metadata_name: + # metadata values should go directly in parameters, as primitive types + step.config.parameters[ + name + ] = config_.model_version.run_metadata[ + config_.metadata_name + ].value + elif config_.metadata_name and config_.artifact_name: + # metadata values should go directly in parameters, as primitive types + if artifact_ := config_.model_version.get_artifact( + config_.artifact_name, config_.artifact_version + ): + step.config.parameters[name] = artifact_.run_metadata[ + config_.metadata_name + ].value + else: + issue_found = True + else: + issue_found = True + except KeyError: + issue_found = True + + if issue_found: + raise ValueError( + "Cannot fetch requested information from model " + f"`{config_.model_version.name}` version " + f"`{config_.model_version.version}` given artifact " + f"`{config_.artifact_name}`, artifact version " + f"`{config_.artifact_version}`, and metadata " + f"key `{config_.metadata_name}` passed into " + f"the step `{step.config.name}`." + ) + parent_step_ids = [ current_run_steps[upstream_step].id for upstream_step in step.spec.upstream_steps diff --git a/src/zenml/steps/base_step.py b/src/zenml/steps/base_step.py index 23eac3367c6..1106797903c 100644 --- a/src/zenml/steps/base_step.py +++ b/src/zenml/steps/base_step.py @@ -70,6 +70,7 @@ StepConfiguration, StepConfigurationUpdate, ) + from zenml.model.lazy_load import ModelVersionDataLazyLoader from zenml.model.model_version import ModelVersion ParametersOrDict = Union["BaseParameters", Dict[str, Any]] @@ -442,6 +443,7 @@ def _parse_call_args( ) -> Tuple[ Dict[str, "StepArtifact"], Dict[str, "ExternalArtifact"], + Dict[str, "ModelVersionDataLazyLoader"], Dict[str, Any], Dict[str, Any], ]: @@ -455,9 +457,14 @@ def _parse_call_args( StepInterfaceError: If invalid function arguments were passed. Returns: - The artifacts, external artifacts and parameters for the step. + The artifacts, external artifacts, model version artifacts/metadata and parameters for the step. """ from zenml.artifacts.external_artifact import ExternalArtifact + from zenml.model.lazy_load import ModelVersionDataLazyLoader + from zenml.models.v2.core.artifact_version import ( + LazyArtifactVersionResponse, + ) + from zenml.models.v2.core.run_metadata import LazyRunMetadataResponse signature = get_step_entrypoint_signature(step=self) @@ -470,6 +477,7 @@ def _parse_call_args( artifacts = {} external_artifacts = {} + model_artifacts_or_metadata = {} parameters = {} default_parameters = {} @@ -495,6 +503,20 @@ def _parse_call_args( "steps. Future releases will introduce hashing of " "artifacts which will improve this behavior." ) + elif isinstance(value, LazyArtifactVersionResponse): + model_artifacts_or_metadata[key] = ModelVersionDataLazyLoader( + model_version=value._lazy_load_model_version, + artifact_name=value._lazy_load_name, + artifact_version=value._lazy_load_version, + metadata_name=None, + ) + elif isinstance(value, LazyRunMetadataResponse): + model_artifacts_or_metadata[key] = ModelVersionDataLazyLoader( + model_version=value._lazy_load_model_version, + artifact_name=value._lazy_load_artifact_name, + artifact_version=value._lazy_load_artifact_version, + metadata_name=value._lazy_load_metadata_name, + ) else: parameters[key] = value @@ -510,11 +532,18 @@ def _parse_call_args( if ( key not in artifacts and key not in external_artifacts + and key not in model_artifacts_or_metadata and key not in self.configuration.parameters ): default_parameters[key] = value - return artifacts, external_artifacts, parameters, default_parameters + return ( + artifacts, + external_artifacts, + model_artifacts_or_metadata, + parameters, + default_parameters, + ) def __call__( self, @@ -549,6 +578,7 @@ def __call__( ( input_artifacts, external_artifacts, + model_artifacts_or_metadata, parameters, default_parameters, ) = self._parse_call_args(*args, **kwargs) @@ -565,6 +595,7 @@ def __call__( step=self, input_artifacts=input_artifacts, external_artifacts=external_artifacts, + model_artifacts_or_metadata=model_artifacts_or_metadata, parameters=parameters, default_parameters=default_parameters, upstream_steps=upstream_steps, @@ -993,6 +1024,7 @@ def _validate_inputs( self, input_artifacts: Dict[str, "StepArtifact"], external_artifacts: Dict[str, "ExternalArtifactConfiguration"], + model_artifacts_or_metadata: Dict[str, "ModelVersionDataLazyLoader"], ) -> None: """Validates the step inputs. @@ -1002,6 +1034,7 @@ def _validate_inputs( Args: input_artifacts: The input artifacts. external_artifacts: The external input artifacts. + model_artifacts_or_metadata: The model artifacts or metadata. Raises: StepInterfaceError: If an entrypoint input is missing. @@ -1011,6 +1044,7 @@ def _validate_inputs( key in input_artifacts or key in self.configuration.parameters or key in external_artifacts + or key in model_artifacts_or_metadata ): continue raise StepInterfaceError(f"Missing entrypoint input {key}.") @@ -1019,6 +1053,7 @@ def _finalize_configuration( self, input_artifacts: Dict[str, "StepArtifact"], external_artifacts: Dict[str, "ExternalArtifactConfiguration"], + model_artifacts_or_metadata: Dict[str, "ModelVersionDataLazyLoader"], ) -> "StepConfiguration": """Finalizes the configuration after the step was called. @@ -1030,6 +1065,8 @@ def _finalize_configuration( Args: input_artifacts: The input artifacts of this step. external_artifacts: The external artifacts of this step. + model_artifacts_or_metadata: The model artifacts or metadata of + this step. Returns: The finalized step configuration. @@ -1102,6 +1139,7 @@ def _finalize_configuration( self._validate_inputs( input_artifacts=input_artifacts, external_artifacts=external_artifacts, + model_artifacts_or_metadata=model_artifacts_or_metadata, ) values = dict_utils.remove_none_values({"outputs": outputs or None}) @@ -1112,6 +1150,7 @@ def _finalize_configuration( update={ "caching_parameters": self.caching_parameters, "external_input_artifacts": external_artifacts, + "model_artifacts_or_metadata": model_artifacts_or_metadata, } ) diff --git a/src/zenml/steps/entrypoint_function_utils.py b/src/zenml/steps/entrypoint_function_utils.py index 52b5173216e..b285acaa51e 100644 --- a/src/zenml/steps/entrypoint_function_utils.py +++ b/src/zenml/steps/entrypoint_function_utils.py @@ -157,6 +157,10 @@ def validate_input(self, key: str, value: Any) -> None: from zenml.artifacts.unmaterialized_artifact import ( UnmaterializedArtifact, ) + from zenml.models import ( + ArtifactVersionResponse, + RunMetadataResponse, + ) if key not in self.inputs: raise KeyError( @@ -165,7 +169,15 @@ def validate_input(self, key: str, value: Any) -> None: parameter = self.inputs[key] - if isinstance(value, (StepArtifact, ExternalArtifact)): + if isinstance( + value, + ( + StepArtifact, + ExternalArtifact, + ArtifactVersionResponse, + RunMetadataResponse, + ), + ): # If we were to do any type validation for artifacts here, we # would not be able to leverage pydantics type coercion (e.g. # providing an `int` artifact for a `float` input) diff --git a/src/zenml/steps/step_invocation.py b/src/zenml/steps/step_invocation.py index 4e3583514a8..b83e252f720 100644 --- a/src/zenml/steps/step_invocation.py +++ b/src/zenml/steps/step_invocation.py @@ -17,6 +17,7 @@ if TYPE_CHECKING: from zenml.artifacts.external_artifact import ExternalArtifact from zenml.config.step_configurations import StepConfiguration + from zenml.model.lazy_load import ModelVersionDataLazyLoader from zenml.new.pipelines.pipeline import Pipeline from zenml.steps import BaseStep from zenml.steps.entrypoint_function_utils import StepArtifact @@ -31,6 +32,7 @@ def __init__( step: "BaseStep", input_artifacts: Dict[str, "StepArtifact"], external_artifacts: Dict[str, "ExternalArtifact"], + model_artifacts_or_metadata: Dict[str, "ModelVersionDataLazyLoader"], parameters: Dict[str, Any], default_parameters: Dict[str, Any], upstream_steps: Set[str], @@ -43,6 +45,8 @@ def __init__( step: The step that is represented by the invocation. input_artifacts: The input artifacts for the invocation. external_artifacts: The external artifacts for the invocation. + model_artifacts_or_metadata: The model artifacts or metadata for + the invocation. parameters: The parameters for the invocation. default_parameters: The default parameters for the invocation. upstream_steps: The upstream steps for the invocation. @@ -52,6 +56,7 @@ def __init__( self.step = step self.input_artifacts = input_artifacts self.external_artifacts = external_artifacts + self.model_artifacts_or_metadata = model_artifacts_or_metadata self.parameters = parameters self.default_parameters = default_parameters self.invocation_upstream_steps = upstream_steps @@ -151,4 +156,5 @@ def finalize(self, parameters_to_ignore: Set[str]) -> "StepConfiguration": return self.step._finalize_configuration( input_artifacts=self.input_artifacts, external_artifacts=external_artifacts, + model_artifacts_or_metadata=self.model_artifacts_or_metadata, ) diff --git a/tests/integration/functional/model/test_model_version.py b/tests/integration/functional/model/test_model_version.py index f0eaae0c926..6e1e48a8a5f 100644 --- a/tests/integration/functional/model/test_model_version.py +++ b/tests/integration/functional/model/test_model_version.py @@ -65,7 +65,7 @@ def __exit__(self, exc_type, exc_value, exc_traceback): def step_metadata_logging_functional(): """Functional logging using implicit ModelVersion from context.""" log_model_metadata({"foo": "bar"}) - assert get_step_context().model_version.metadata["foo"] == "bar" + assert get_step_context().model_version.run_metadata["foo"].value == "bar" @step @@ -346,14 +346,14 @@ def test_metadata_logging(self, clean_client: "Client"): ) mv.log_metadata({"foo": "bar"}) - assert len(mv.metadata) == 1 - assert mv.metadata["foo"] == "bar" + assert len(mv.run_metadata) == 1 + assert mv.run_metadata["foo"].value == "bar" mv.log_metadata({"bar": "foo"}) - assert len(mv.metadata) == 2 - assert mv.metadata["foo"] == "bar" - assert mv.metadata["bar"] == "foo" + assert len(mv.run_metadata) == 2 + assert mv.run_metadata["foo"].value == "bar" + assert mv.run_metadata["bar"].value == "foo" def test_metadata_logging_functional(self, clean_client: "Client"): """Test that model version can be used to track metadata from function.""" @@ -367,8 +367,8 @@ def test_metadata_logging_functional(self, clean_client: "Client"): {"foo": "bar"}, model_name=mv.name, model_version=mv.number ) - assert len(mv.metadata) == 1 - assert mv.metadata["foo"] == "bar" + assert len(mv.run_metadata) == 1 + assert mv.run_metadata["foo"].value == "bar" with pytest.raises(ValueError): log_model_metadata({"foo": "bar"}) @@ -377,9 +377,9 @@ def test_metadata_logging_functional(self, clean_client: "Client"): {"bar": "foo"}, model_name=mv.name, model_version="latest" ) - assert len(mv.metadata) == 2 - assert mv.metadata["foo"] == "bar" - assert mv.metadata["bar"] == "foo" + assert len(mv.run_metadata) == 2 + assert mv.run_metadata["foo"].value == "bar" + assert mv.run_metadata["bar"].value == "foo" def test_metadata_logging_in_steps(self, clean_client: "Client"): """Test that model version can be used to track metadata from function in steps.""" @@ -396,8 +396,8 @@ def my_pipeline(): my_pipeline() mv = ModelVersion(name=MODEL_NAME, version="latest") - assert len(mv.metadata) == 1 - assert mv.metadata["foo"] == "bar" + assert len(mv.run_metadata) == 1 + assert mv.run_metadata["foo"].value == "bar" @pytest.mark.parametrize("delete_artifacts", [False, True]) def test_deletion_of_links( diff --git a/tests/integration/functional/pipelines/test_pipeline_context.py b/tests/integration/functional/pipelines/test_pipeline_context.py index 9f805f073da..d6c628b5a62 100644 --- a/tests/integration/functional/pipelines/test_pipeline_context.py +++ b/tests/integration/functional/pipelines/test_pipeline_context.py @@ -1,5 +1,6 @@ import pytest import yaml +from typing_extensions import Annotated from zenml import ( ModelVersion, @@ -8,7 +9,9 @@ pipeline, step, ) +from zenml.artifacts.utils import log_artifact_metadata from zenml.client import Client +from zenml.model.utils import log_model_version_metadata @step @@ -33,13 +36,9 @@ def assert_pipeline_context_in_pipeline(): context = get_pipeline_context() assert ( context.name == "assert_pipeline_context_in_pipeline" - ), "Not accessible inside composition of pipeline" - assert ( - context.enable_cache is False - ), "Not accessible inside composition of pipeline" - assert context.extra == { - "foo": "bar" - }, "Not accessible inside composition of pipeline" + ), "Not accessible inside pipeline" + assert context.enable_cache is False, "Not accessible inside pipeline" + assert context.extra == {"foo": "bar"}, "Not accessible inside pipeline" assert_pipeline_context_in_step() @@ -97,3 +96,49 @@ def test_that_argument_as_get_artifact_of_model_version_in_pipeline_context_fail producer_pipe(False) with pytest.raises(RuntimeError): consumer_pipe() + + +@step +def producer() -> Annotated[str, "bar"]: + """Produce artifact with metadata and attach metadata to model version.""" + ver = get_step_context().model_version.version + log_model_version_metadata(metadata={"foobar": "model_meta_" + ver}) + log_artifact_metadata(metadata={"foobar": "artifact_meta_" + ver}) + return "artifact_data_" + ver + + +@step +def asserter(artifact: str, artifact_metadata: str, model_metadata: str): + """Assert that passed in values are loaded in lazy mode. + + They do not exists before actual run of the pipeline. + """ + ver = get_step_context().model_version.version + assert artifact == "artifact_data_" + ver + assert artifact_metadata == "artifact_meta_" + ver + assert model_metadata == "model_meta_" + ver + + +def test_pipeline_context_can_load_model_artifacts_and_metadata_in_lazy_mode( + clean_client: "Client", +): + """Tests that user can load model artifacts and metadata in lazy mode in pipeline codes.""" + + model_name = "foo" + + @pipeline(model_version=ModelVersion(name=model_name), enable_cache=False) + def dummy(): + producer() + with pytest.raises(KeyError): + clean_client.get_model(model_name) + with pytest.raises(KeyError): + clean_client.get_artifact_version("bar") + model_version = get_pipeline_context().model_version + artifact = model_version.get_artifact("bar") + artifact_metadata = artifact.run_metadata["foobar"] + model_metadata = model_version.run_metadata["foobar"] + asserter( + artifact, artifact_metadata, model_metadata, after=["producer"] + ) + + dummy() diff --git a/tests/unit/materializers/test_materializer_registry.py b/tests/unit/materializers/test_materializer_registry.py index 0c36e6012f7..4afc0e27031 100644 --- a/tests/unit/materializers/test_materializer_registry.py +++ b/tests/unit/materializers/test_materializer_registry.py @@ -81,7 +81,9 @@ def some_step() -> MyConflictingType: step_instance = some_step() with does_not_raise(): step_instance._finalize_configuration( - input_artifacts={}, external_artifacts={} + input_artifacts={}, + external_artifacts={}, + model_artifacts_or_metadata={}, ) # The step uses the materializer registered for the earliest class in the diff --git a/tests/unit/orchestrators/test_cache_utils.py b/tests/unit/orchestrators/test_cache_utils.py index 1d879d73cf2..40264de698c 100644 --- a/tests/unit/orchestrators/test_cache_utils.py +++ b/tests/unit/orchestrators/test_cache_utils.py @@ -37,6 +37,7 @@ def _compile_step(step: BaseStep) -> Step: step=step, input_artifacts={}, external_artifacts={}, + model_artifacts_or_metadata={}, parameters={}, default_parameters={}, upstream_steps=set(), diff --git a/tests/unit/steps/test_base_step.py b/tests/unit/steps/test_base_step.py index f40fbe729c1..c0d00681304 100644 --- a/tests/unit/steps/test_base_step.py +++ b/tests/unit/steps/test_base_step.py @@ -325,7 +325,7 @@ def some_step(params: ParamsWithDefaultValues) -> None: # don't pass the config when initializing the step step_instance = some_step() - step_instance._finalize_configuration({}, {}) + step_instance._finalize_configuration({}, {}, {}) assert ( step_instance.configuration.parameters["params"]["some_parameter"] == 1