Skip to content

Commit

Permalink
Fix get_pipeline_context().model_version.get_artifact(...) flow (#2162)
Browse files (browse the repository at this point in the history)

* remove orphaned linkage

* model aware ExtArt

* rework ExtArt linkage

* remove buggy test

* add test case

* Auto-update of E2E template

* Auto-update of E2E template

---------

Co-authored-by: GitHub Actions <[email protected]>
  • Loading branch information
avishniakov and actions-user authored Dec 19, 2023
1 parent 6468a22 commit 31cc9cb
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 46 deletions.
2 changes: 1 addition & 1 deletion examples/e2e/.copier-answers.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Changes here will be overwritten by Copier
_commit: 2023.12.06-4-g1e3edc6
_commit: 2023.12.12
_src_path: gh:zenml-io/template-e2e-batch
data_quality_checks: true
email: ''
Expand Down
5 changes: 4 additions & 1 deletion src/zenml/artifacts/external_artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ class ExternalArtifact(ExternalArtifactConfiguration):
`version`, `pipeline_run_name`, or `pipeline_name` are set, the
latest version of the artifact will be used.
version: Version of the artifact to search. Only used when `name` is
provided.
provided. Cannot be used together with `model_version`.
model_version: The model version to search in. Only used when `name`
is provided. Cannot be used together with `version`.
materializer: The materializer to use for saving the artifact value
to the artifact store. Only used when `value` is provided.
store_artifact_metadata: Whether metadata for the artifact should
Expand Down Expand Up @@ -147,4 +149,5 @@ def config(self) -> ExternalArtifactConfiguration:
id=self.id,
name=self.name,
version=self.version,
model_version=self.model_version,
)
35 changes: 19 additions & 16 deletions src/zenml/artifacts/external_artifact_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,14 @@
# or implied. See the License for the specific language governing
# permissions and limitations under the License.
"""External artifact definition."""
from typing import TYPE_CHECKING, Optional
from typing import Any, Dict, Optional
from uuid import UUID

from pydantic import BaseModel, PrivateAttr
from pydantic import BaseModel, root_validator

from zenml.logger import get_logger
from zenml.models.v2.core.artifact import ArtifactResponse

if TYPE_CHECKING:
from zenml.model.model_version import ModelVersion
from zenml.model.model_version import ModelVersion
from zenml.models.v2.core.artifact_version import ArtifactVersionResponse

logger = get_logger(__name__)

Expand All @@ -35,11 +33,16 @@ class ExternalArtifactConfiguration(BaseModel):
id: Optional[UUID] = None
name: Optional[str] = None
version: Optional[str] = None

_model_version: Optional["ModelVersion"] = PrivateAttr(None)

def _set_model_version(self, model_version: "ModelVersion") -> None:
self._model_version = model_version
model_version: Optional[ModelVersion] = None

@root_validator
def _validate_all_eac(cls, values: Dict[str, Any]) -> Dict[str, Any]:
    """Reject configurations that set both `version` and `model_version`.

    Args:
        values: The field values of the model being validated.

    Returns:
        The `values` mapping, unchanged.

    Raises:
        ValueError: If both `version` and `model_version` are truthy.
    """
    # Guard clauses keep the original short-circuit order: `model_version`
    # is only inspected once `version` is known to be set.
    if not values.get("version"):
        return values
    if not values.get("model_version"):
        return values
    raise ValueError(
        "Cannot provide both `version` and `model_version` when "
        "creating an external artifact."
    )

def get_artifact_version_id(self) -> UUID:
"""Get the artifact.
Expand All @@ -64,13 +67,13 @@ def get_artifact_version_id(self) -> UUID:
response = client.get_artifact_version(
self.name, version=self.version
)
elif self._model_version:
response_ = self._model_version.get_artifact(self.name)
if not isinstance(response_, ArtifactResponse):
elif self.model_version:
response_ = self.model_version.get_artifact(self.name)
if not isinstance(response_, ArtifactVersionResponse):
raise RuntimeError(
f"Failed to pull artifact `{self.name}` from the Model "
f"Version (name=`{self._model_version.name}`, version="
f"`{self._model_version.version}`). Please validate the "
f"Version (name=`{self.model_version.name}`, version="
f"`{self.model_version.version}`). Please validate the "
"input and try again."
)
response = response_
Expand Down
3 changes: 1 addition & 2 deletions src/zenml/model/model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,7 @@ def _try_get_as_external_artifact(
except RuntimeError:
return None

ea = ExternalArtifact(name=name, version=version)
ea._set_model_version(self)
ea = ExternalArtifact(name=name, version=version, model_version=self)
return ea

def get_artifact(
Expand Down
16 changes: 6 additions & 10 deletions src/zenml/orchestrators/step_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -731,17 +731,13 @@ def _link_pipeline_run_to_model_from_artifacts(

# Add models from external artifacts
for external_artifact in external_artifacts:
try:
artifact_version_id = (
external_artifact.get_artifact_version_id()
)
links = client.list_model_version_artifact_links(
artifact_version_id=artifact_version_id,
if external_artifact.model_version:
models.add(
(
external_artifact.model_version.model_id,
external_artifact.model_version.id,
)
)
for link in links:
models.add((link.model, link.model_version))
except RuntimeError: # artifacts uploaded by value have no models
pass

for model in models:
client.zen_store.create_model_version_pipeline_run_link(
Expand Down
79 changes: 63 additions & 16 deletions tests/integration/functional/steps/test_model_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import pytest
from typing_extensions import Annotated

from zenml import get_step_context, pipeline, step
from zenml import get_pipeline_context, get_step_context, pipeline, step
from zenml.artifacts.artifact_config import ArtifactConfig
from zenml.artifacts.external_artifact import ExternalArtifact
from zenml.client import Client
Expand Down Expand Up @@ -601,14 +601,6 @@ def _consumer_pipeline_with_step_context():
)(ExternalArtifact(name="output_0"), 1)


@pipeline
def _consumer_pipeline_with_artifact_context():
    """Run `_consumer_step` on the external artifact named `output_1`."""
    consumed_artifact = ExternalArtifact(name="output_1")
    _consumer_step(consumed_artifact, 2)


@pipeline(model_version=ModelVersion(name="step", version=ModelStages.LATEST))
def _consumer_pipeline_with_pipeline_context():
_consumer_step(
Expand All @@ -629,31 +621,26 @@ def test_that_consumption_also_registers_run_in_model_version(
producer_run = f"producer_run_{uuid4()}"
consumer_run_1 = f"consumer_run_1_{uuid4()}"
consumer_run_2 = f"consumer_run_2_{uuid4()}"
consumer_run_3 = f"consumer_run_3_{uuid4()}"
_producer_pipeline.with_options(
run_name=producer_run, enable_cache=False
)()
_consumer_pipeline_with_step_context.with_options(
run_name=consumer_run_1
)()
_consumer_pipeline_with_artifact_context.with_options(
run_name=consumer_run_2
)()
_consumer_pipeline_with_pipeline_context.with_options(
run_name=consumer_run_3
run_name=consumer_run_2
)()

mv = clean_client.get_model_version(
model_name_or_id="step",
model_version_name_or_number_or_id=ModelStages.LATEST,
)

assert len(mv.pipeline_run_ids) == 4
assert len(mv.pipeline_run_ids) == 3
assert {run_name for run_name in mv.pipeline_run_ids} == {
producer_run,
consumer_run_1,
consumer_run_2,
consumer_run_3,
}


Expand Down Expand Up @@ -800,3 +787,63 @@ def _inner_pipeline():
)
== 0
)


@step
def _this_step_asserts_context_with_artifact(artifact: str):
    """Assert that `artifact` equals the step's model version ID as a string."""
    current_model_version_id = str(get_step_context().model_version.id)
    assert artifact == current_model_version_id


@step
def _this_step_produces_output_model_version() -> (
    Annotated[str, ArtifactConfig(name="artifact")]
):
    """Return the step's model version ID as a string, stored as `artifact`."""
    active_model_version = get_step_context().model_version
    return str(active_model_version.id)


def test_pipeline_context_pass_artifact_from_model_and_link_run(
    clean_client: "Client",
):
    """Test artifact lookup through the pipeline context and run linkage.

    The producer pipeline runs twice: `run_1` sets its model version to the
    production stage, `run_2` does not. The consumer pipeline (`run_3`) is
    configured with the production stage, fetches `artifact` through
    `get_pipeline_context().model_version` and hands it to a step that
    compares it with the model version ID of its own step context.
    Afterwards the latest model version must list only `run_2`, and the
    production one must list `run_1` and `run_3`.
    """

    @pipeline(model_version=ModelVersion(name="pipeline"), enable_cache=False)
    def _producer(do_promote: bool):
        _this_step_produces_output_model_version()
        if do_promote:
            pipeline_model_version = get_pipeline_context().model_version
            pipeline_model_version.set_stage(ModelStages.PRODUCTION)

    production_model_version = ModelVersion(
        name="pipeline", version=ModelStages.PRODUCTION
    )

    @pipeline(model_version=production_model_version, enable_cache=False)
    def _consumer():
        pipeline_model_version = get_pipeline_context().model_version
        _this_step_asserts_context_with_artifact(
            pipeline_model_version.get_artifact("artifact")
        )

    def _linked_runs(stage):
        # Fetch the model version fresh so links created by the runs above
        # are included.
        return clean_client.get_model_version(
            model_name_or_id="pipeline",
            model_version_name_or_number_or_id=stage,
        ).pipeline_run_ids

    _producer.with_options(run_name="run_1")(True)
    _producer.with_options(run_name="run_2")(False)
    _consumer.with_options(run_name="run_3")()

    latest_runs = _linked_runs(ModelStages.LATEST)
    assert len(latest_runs) == 1
    assert set(latest_runs) == {"run_2"}

    production_runs = _linked_runs(ModelStages.PRODUCTION)
    assert len(production_runs) == 2
    assert set(production_runs) == {"run_1", "run_3"}

0 comments on commit 31cc9cb

Please sign in to comment.