Skip to content

Commit dd52385

Browse files
Moving tags to the body for artifacts and artifact versions (#2138)
* moving tags * Auto-update of E2E template --------- Co-authored-by: GitHub Actions <[email protected]>
1 parent 9a67d5f commit dd52385

File tree

7 files changed

+41
-38
lines changed

7 files changed

+41
-38
lines changed

examples/e2e/.copier-answers.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Changes here will be overwritten by Copier
2-
_commit: 2023.11.23-2-gc19b794
2+
_commit: 2023.12.06
33
_src_path: gh:zenml-io/template-e2e-batch
44
data_quality_checks: true
55
email: ''

examples/e2e/configs/train_config.yaml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,9 @@ steps:
3232
model_trainer:
3333
parameters:
3434
name: e2e_use_case
35-
compute_performance_metrics_on_current_data:
36-
parameters:
37-
target_env: staging
3835
promote_with_metric_compare:
3936
parameters:
4037
mlflow_model_name: e2e_use_case
41-
target_env: staging
4238
notify_on_success:
4339
parameters:
4440
notify_on_success: False
@@ -65,6 +61,9 @@ model_version:
6561
# pipeline level extra configurations
6662
extra:
6763
notify_on_failure: True
64+
# pipeline level parameters
65+
parameters:
66+
target_env: staging
6867
# This set contains all the model configurations that you want
6968
# to evaluate during hyperparameter tuning stage.
7069
model_search_space:

examples/e2e/pipelines/training.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818

1919
import random
20-
from typing import List, Optional
20+
from typing import Any, Dict, List, Optional
2121

2222
from steps import (
2323
compute_performance_metrics_on_current_data,
@@ -33,14 +33,16 @@
3333
train_data_splitter,
3434
)
3535

36-
from zenml import get_pipeline_context, pipeline
36+
from zenml import pipeline
3737
from zenml.logger import get_logger
3838

3939
logger = get_logger(__name__)
4040

4141

4242
@pipeline(on_failure=notify_on_failure)
4343
def e2e_use_case_training(
44+
model_search_space: Dict[str, Any],
45+
target_env: str,
4446
test_size: float = 0.2,
4547
drop_na: Optional[bool] = None,
4648
normalize: Optional[bool] = None,
@@ -57,6 +59,8 @@ def e2e_use_case_training(
5759
trains and evaluates a model.
5860
5961
Args:
62+
model_search_space: Search space for hyperparameter tuning
63+
target_env: The environment to promote the model to
6064
test_size: Size of holdout set for training 0.0..1.0
6165
drop_na: If `True` NA values will be removed from dataset
6266
normalize: If `True` dataset will be normalized with MinMaxScaler
@@ -65,12 +69,10 @@ def e2e_use_case_training(
6569
min_test_accuracy: Threshold to stop execution if test set accuracy is lower
6670
fail_on_accuracy_quality_gates: If `True` and `min_train_accuracy` or `min_test_accuracy`
6771
are not met - execution will be interrupted early
68-
6972
"""
7073
### ADD YOUR OWN CODE HERE - THIS IS JUST AN EXAMPLE ###
7174
# Link all the steps together by calling them and passing the output
7275
# of one step as the input of the next step.
73-
pipeline_extra = get_pipeline_context().extra
7476
########## ETL stage ##########
7577
raw_data, target, _ = data_loader(random_state=random.randint(0, 100))
7678
dataset_trn, dataset_tst = train_data_splitter(
@@ -87,9 +89,7 @@ def e2e_use_case_training(
8789
########## Hyperparameter tuning stage ##########
8890
after = []
8991
search_steps_prefix = "hp_tuning_search_"
90-
for config_name, model_search_configuration in pipeline_extra[
91-
"model_search_space"
92-
].items():
92+
for config_name, model_search_configuration in model_search_space.items():
9393
step_name = f"{search_steps_prefix}{config_name}"
9494
hp_tuning_single_search(
9595
id=step_name,
@@ -123,12 +123,15 @@ def e2e_use_case_training(
123123
latest_metric,
124124
current_metric,
125125
) = compute_performance_metrics_on_current_data(
126-
dataset_tst=dataset_tst, after=["model_evaluator"]
126+
dataset_tst=dataset_tst,
127+
target_env=target_env,
128+
after=["model_evaluator"],
127129
)
128130

129131
promote_with_metric_compare(
130132
latest_metric=latest_metric,
131133
current_metric=current_metric,
134+
target_env=target_env,
132135
)
133136
last_step = "promote_with_metric_compare"
134137

src/zenml/models/v2/core/artifact.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ class ArtifactUpdate(BaseModel):
6363
class ArtifactResponseBody(BaseResponseBody):
6464
"""Response body for artifacts."""
6565

66+
tags: List[TagResponseModel] = Field(
67+
title="Tags associated with the model",
68+
)
69+
6670

6771
class ArtifactResponseMetadata(BaseResponseMetadata):
6872
"""Response metadata for artifacts."""
@@ -71,9 +75,6 @@ class ArtifactResponseMetadata(BaseResponseMetadata):
7175
title="Whether the name is custom (True) or auto-generated (False).",
7276
default=False,
7377
)
74-
tags: List[TagResponseModel] = Field(
75-
title="Tags associated with the model",
76-
)
7778

7879

7980
class ArtifactResponse(
@@ -98,22 +99,22 @@ def get_hydrated_version(self) -> "ArtifactResponse":
9899

99100
# Body and metadata properties
100101
@property
101-
def has_custom_name(self) -> bool:
102-
"""The `has_custom_name` property.
102+
def tags(self) -> List[TagResponseModel]:
103+
"""The `tags` property.
103104
104105
Returns:
105106
the value of the property.
106107
"""
107-
return self.get_metadata().has_custom_name
108+
return self.get_body().tags
108109

109110
@property
110-
def tags(self) -> List[TagResponseModel]:
111-
"""The `tags` property.
111+
def has_custom_name(self) -> bool:
112+
"""The `has_custom_name` property.
112113
113114
Returns:
114115
the value of the property.
115116
"""
116-
return self.get_metadata().tags
117+
return self.get_metadata().has_custom_name
117118

118119
# Helper methods
119120
@property

src/zenml/models/v2/core/artifact_version.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,9 @@ class ArtifactVersionResponseBody(WorkspaceScopedResponseBody):
134134
data_type: Source = Field(
135135
title="Data type of the artifact.",
136136
)
137+
tags: List[TagResponseModel] = Field(
138+
title="Tags associated with the model",
139+
)
137140

138141
_convert_source = convert_source_validator("materializer", "data_type")
139142

@@ -149,9 +152,6 @@ class ArtifactVersionResponseMetadata(WorkspaceScopedResponseMetadata):
149152
title="ID of the step run that produced this artifact.",
150153
default=None,
151154
)
152-
tags: List[TagResponseModel] = Field(
153-
title="Tags associated with the model",
154-
)
155155
visualizations: Optional[List["ArtifactVisualizationResponse"]] = Field(
156156
default=None, title="Visualizations of the artifact."
157157
)
@@ -215,31 +215,31 @@ def type(self) -> ArtifactType:
215215
return self.get_body().type
216216

217217
@property
218-
def artifact_store_id(self) -> Optional[UUID]:
219-
"""The `artifact_store_id` property.
218+
def tags(self) -> List[TagResponseModel]:
219+
"""The `tags` property.
220220
221221
Returns:
222222
the value of the property.
223223
"""
224-
return self.get_metadata().artifact_store_id
224+
return self.get_body().tags
225225

226226
@property
227-
def producer_step_run_id(self) -> Optional[UUID]:
228-
"""The `producer_step_run_id` property.
227+
def artifact_store_id(self) -> Optional[UUID]:
228+
"""The `artifact_store_id` property.
229229
230230
Returns:
231231
the value of the property.
232232
"""
233-
return self.get_metadata().producer_step_run_id
233+
return self.get_metadata().artifact_store_id
234234

235235
@property
236-
def tags(self) -> List[TagResponseModel]:
237-
"""The `tags` property.
236+
def producer_step_run_id(self) -> Optional[UUID]:
237+
"""The `producer_step_run_id` property.
238238
239239
Returns:
240240
the value of the property.
241241
"""
242-
return self.get_metadata().tags
242+
return self.get_metadata().producer_step_run_id
243243

244244
@property
245245
def visualizations(

src/zenml/zen_stores/schemas/artifact_schemas.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,14 +113,14 @@ def to_model(self, hydrate: bool = False) -> ArtifactResponse:
113113
body = ArtifactResponseBody(
114114
created=self.created,
115115
updated=self.updated,
116+
tags=[t.tag.to_model() for t in self.tags],
116117
)
117118

118119
# Create the metadata of the model
119120
metadata = None
120121
if hydrate:
121122
metadata = ArtifactResponseMetadata(
122123
has_custom_name=self.has_custom_name,
123-
tags=[t.tag.to_model() for t in self.tags],
124124
)
125125

126126
return ArtifactResponse(
@@ -299,6 +299,7 @@ def to_model(self, hydrate: bool = False) -> ArtifactVersionResponse:
299299
data_type=data_type,
300300
created=self.created,
301301
updated=self.updated,
302+
tags=[t.tag.to_model() for t in self.tags],
302303
)
303304

304305
# Create the metadata of the model
@@ -318,7 +319,6 @@ def to_model(self, hydrate: bool = False) -> ArtifactVersionResponse:
318319
producer_step_run_id=producer_step_run_id,
319320
visualizations=[v.to_model() for v in self.visualizations],
320321
run_metadata={m.key: m.to_model() for m in self.run_metadata},
321-
tags=[t.tag.to_model() for t in self.tags],
322322
)
323323

324324
return ArtifactVersionResponse(

tests/unit/conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -484,10 +484,10 @@ def sample_artifact_model() -> ArtifactResponse:
484484
body=ArtifactResponseBody(
485485
created=datetime.now(),
486486
updated=datetime.now(),
487+
tags=[],
487488
),
488489
metadata=ArtifactResponseMetadata(
489490
has_custom_name=True,
490-
tags=[],
491491
),
492492
)
493493

@@ -509,10 +509,10 @@ def sample_artifact_version_model(
509509
type=ArtifactType.DATA,
510510
materializer="sample_module.sample_materializer",
511511
data_type="sample_module.sample_data_type",
512+
tags=[],
512513
),
513514
metadata=ArtifactVersionResponseMetadata(
514515
workspace=sample_workspace_model,
515-
tags=[],
516516
),
517517
)
518518

0 commit comments

Comments
 (0)