From 31cc9cb291ac5bf46843b42e279bb804255b1c73 Mon Sep 17 00:00:00 2001 From: Andrei Vishniakov <31008759+avishniakov@users.noreply.github.com> Date: Tue, 19 Dec 2023 10:48:23 +0100 Subject: [PATCH] Fix `get_pipeline_context().model_version.get_artifact(...)` flow (#2162) * remove orphaned linkage * model aware ExtArt * rework ExtArt linkage * remove buggy test * add test case * Auto-update of E2E template * Auto-update of E2E template --------- Co-authored-by: GitHub Actions --- examples/e2e/.copier-answers.yml | 2 +- src/zenml/artifacts/external_artifact.py | 5 +- .../artifacts/external_artifact_config.py | 35 ++++---- src/zenml/model/model_version.py | 3 +- src/zenml/orchestrators/step_runner.py | 16 ++-- .../functional/steps/test_model_version.py | 79 +++++++++++++++---- 6 files changed, 94 insertions(+), 46 deletions(-) diff --git a/examples/e2e/.copier-answers.yml b/examples/e2e/.copier-answers.yml index 7a288d71456..04637068e52 100644 --- a/examples/e2e/.copier-answers.yml +++ b/examples/e2e/.copier-answers.yml @@ -1,5 +1,5 @@ # Changes here will be overwritten by Copier -_commit: 2023.12.06-4-g1e3edc6 +_commit: 2023.12.12 _src_path: gh:zenml-io/template-e2e-batch data_quality_checks: true email: '' diff --git a/src/zenml/artifacts/external_artifact.py b/src/zenml/artifacts/external_artifact.py index be65bb7c78f..cdf1e90c091 100644 --- a/src/zenml/artifacts/external_artifact.py +++ b/src/zenml/artifacts/external_artifact.py @@ -56,7 +56,9 @@ class ExternalArtifact(ExternalArtifactConfiguration): `version`, `pipeline_run_name`, or `pipeline_name` are set, the latest version of the artifact will be used. version: Version of the artifact to search. Only used when `name` is - provided. + provided. Cannot be used together with `model_version`. + model_version: The model version to search in. Only used when `name` + is provided. Cannot be used together with `version`. materializer: The materializer to use for saving the artifact value to the artifact store. Only used when `value` is provided. store_artifact_metadata: Whether metadata for the artifact should @@ -147,4 +149,5 @@ def config(self) -> ExternalArtifactConfiguration: id=self.id, name=self.name, version=self.version, + model_version=self.model_version, ) diff --git a/src/zenml/artifacts/external_artifact_config.py b/src/zenml/artifacts/external_artifact_config.py index 51ea4f2a7f9..8bdea8aa742 100644 --- a/src/zenml/artifacts/external_artifact_config.py +++ b/src/zenml/artifacts/external_artifact_config.py @@ -12,16 +12,14 @@ # or implied. See the License for the specific language governing # permissions and limitations under the License. """External artifact definition.""" -from typing import TYPE_CHECKING, Optional +from typing import Any, Dict, Optional from uuid import UUID -from pydantic import BaseModel, PrivateAttr +from pydantic import BaseModel, root_validator from zenml.logger import get_logger -from zenml.models.v2.core.artifact import ArtifactResponse - -if TYPE_CHECKING: - from zenml.model.model_version import ModelVersion +from zenml.model.model_version import ModelVersion +from zenml.models.v2.core.artifact_version import ArtifactVersionResponse logger = get_logger(__name__) @@ -35,11 +33,16 @@ class ExternalArtifactConfiguration(BaseModel): id: Optional[UUID] = None name: Optional[str] = None version: Optional[str] = None - - _model_version: Optional["ModelVersion"] = PrivateAttr(None) - - def _set_model_version(self, model_version: "ModelVersion") -> None: - self._model_version = model_version + model_version: Optional[ModelVersion] = None + + @root_validator + def _validate_all_eac(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if values.get("version", None) and values.get("model_version", None): + raise ValueError( + "Cannot provide both `version` and `model_version` when " + "creating an external artifact." + ) + return values def get_artifact_version_id(self) -> UUID: """Get the artifact. @@ -64,13 +67,13 @@ def get_artifact_version_id(self) -> UUID: response = client.get_artifact_version( self.name, version=self.version ) - elif self._model_version: - response_ = self._model_version.get_artifact(self.name) - if not isinstance(response_, ArtifactResponse): + elif self.model_version: + response_ = self.model_version.get_artifact(self.name) + if not isinstance(response_, ArtifactVersionResponse): raise RuntimeError( f"Failed to pull artifact `{self.name}` from the Model " - f"Version (name=`{self._model_version.name}`, version=" - f"`{self._model_version.version}`). Please validate the " + f"Version (name=`{self.model_version.name}`, version=" + f"`{self.model_version.version}`). Please validate the " "input and try again." ) response = response_ diff --git a/src/zenml/model/model_version.py b/src/zenml/model/model_version.py index 9bee3d71403..37bb6c12e5f 100644 --- a/src/zenml/model/model_version.py +++ b/src/zenml/model/model_version.py @@ -193,8 +193,7 @@ def _try_get_as_external_artifact( except RuntimeError: return None - ea = ExternalArtifact(name=name, version=version) - ea._set_model_version(self) + ea = ExternalArtifact(name=name, version=version, model_version=self) return ea def get_artifact( diff --git a/src/zenml/orchestrators/step_runner.py b/src/zenml/orchestrators/step_runner.py index 79b4040b2c0..1fa0470754e 100644 --- a/src/zenml/orchestrators/step_runner.py +++ b/src/zenml/orchestrators/step_runner.py @@ -731,17 +731,13 @@ def _link_pipeline_run_to_model_from_artifacts( # Add models from external artifacts for external_artifact in external_artifacts: - try: - artifact_version_id = ( - external_artifact.get_artifact_version_id() - ) - links = client.list_model_version_artifact_links( - artifact_version_id=artifact_version_id, + if external_artifact.model_version: + models.add( + ( + external_artifact.model_version.model_id, + external_artifact.model_version.id, + ) ) - for link in links: - models.add((link.model, link.model_version)) - except RuntimeError: # artifacts uploaded by value have no models - pass for model in models: client.zen_store.create_model_version_pipeline_run_link( diff --git a/tests/integration/functional/steps/test_model_version.py b/tests/integration/functional/steps/test_model_version.py index 393f4ea1717..7bf71258560 100644 --- a/tests/integration/functional/steps/test_model_version.py +++ b/tests/integration/functional/steps/test_model_version.py @@ -19,7 +19,7 @@ import pytest from typing_extensions import Annotated -from zenml import get_step_context, pipeline, step +from zenml import get_pipeline_context, get_step_context, pipeline, step from zenml.artifacts.artifact_config import ArtifactConfig from zenml.artifacts.external_artifact import ExternalArtifact from zenml.client import Client @@ -601,14 +601,6 @@ def _consumer_pipeline_with_step_context(): )(ExternalArtifact(name="output_0"), 1) -@pipeline -def _consumer_pipeline_with_artifact_context(): - _consumer_step( - ExternalArtifact(name="output_1"), - 2, - ) - - @pipeline(model_version=ModelVersion(name="step", version=ModelStages.LATEST)) def _consumer_pipeline_with_pipeline_context(): _consumer_step( @@ -629,18 +621,14 @@ def test_that_consumption_also_registers_run_in_model_version( producer_run = f"producer_run_{uuid4()}" consumer_run_1 = f"consumer_run_1_{uuid4()}" consumer_run_2 = f"consumer_run_2_{uuid4()}" - consumer_run_3 = f"consumer_run_3_{uuid4()}" _producer_pipeline.with_options( run_name=producer_run, enable_cache=False )() _consumer_pipeline_with_step_context.with_options( run_name=consumer_run_1 )() - _consumer_pipeline_with_artifact_context.with_options( - run_name=consumer_run_2 - )() _consumer_pipeline_with_pipeline_context.with_options( - run_name=consumer_run_3 + run_name=consumer_run_2 )() mv = clean_client.get_model_version( @@ -648,12 +636,11 @@ def test_that_consumption_also_registers_run_in_model_version( model_version_name_or_number_or_id=ModelStages.LATEST, ) - assert len(mv.pipeline_run_ids) == 4 + assert len(mv.pipeline_run_ids) == 3 assert {run_name for run_name in mv.pipeline_run_ids} == { producer_run, consumer_run_1, consumer_run_2, - consumer_run_3, } @@ -800,3 +787,63 @@ def _inner_pipeline(): ) == 0 ) + + +@step +def _this_step_asserts_context_with_artifact(artifact: str): + """Assert given arg with model_version number.""" + assert artifact == str(get_step_context().model_version.id) + + +@step +def _this_step_produces_output_model_version() -> ( + Annotated[str, ArtifactConfig(name="artifact")] +): + """This step produces artifact with model_version number.""" + return str(get_step_context().model_version.id) + + +def test_pipeline_context_pass_artifact_from_model_and_link_run( + clean_client: "Client", +): + """Test that ExternalArtifact from pipeline context is matched to proper version and run is linked.""" + + @pipeline(model_version=ModelVersion(name="pipeline"), enable_cache=False) + def _producer(do_promote: bool): + _this_step_produces_output_model_version() + if do_promote: + get_pipeline_context().model_version.set_stage( + ModelStages.PRODUCTION + ) + + @pipeline( + model_version=ModelVersion( + name="pipeline", version=ModelStages.PRODUCTION + ), + enable_cache=False, + ) + def _consumer(): + artifact = get_pipeline_context().model_version.get_artifact( + "artifact" + ) + _this_step_asserts_context_with_artifact(artifact) + + _producer.with_options(run_name="run_1")(True) + _producer.with_options(run_name="run_2")(False) + _consumer.with_options(run_name="run_3")() + + mv = clean_client.get_model_version( + model_name_or_id="pipeline", + model_version_name_or_number_or_id=ModelStages.LATEST, + ) + + assert len(mv.pipeline_run_ids) == 1 + assert {run_name for run_name in mv.pipeline_run_ids} == {"run_2"} + + mv = clean_client.get_model_version( + model_name_or_id="pipeline", + model_version_name_or_number_or_id=ModelStages.PRODUCTION, + ) + + assert len(mv.pipeline_run_ids) == 2 + assert {run_name for run_name in mv.pipeline_run_ids} == {"run_1", "run_3"}