Skip to content

Commit 0b5e6af

Browse files
Bugfix: run_metadata value returns string instead of other types (#2149)
* properly handle primitive types * extend tests * Auto-update of E2E template * Auto-update of NLP template --------- Co-authored-by: GitHub Actions <[email protected]>
1 parent 3cb30fb commit 0b5e6af

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

src/zenml/models/v2/core/run_metadata.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ class RunMetadataRequest(WorkspaceScopedRequest):
5151
title="The types of the metadata to be created.",
5252
)
5353

54+
class Config:
55+
"""Pydantic configuration."""
56+
57+
smart_union = True
58+
5459

5560
# ------------------ Update Model ------------------
5661

@@ -75,6 +80,11 @@ class RunMetadataResponseBody(WorkspaceScopedResponseBody):
7580
max_length=STR_FIELD_MAX_LENGTH,
7681
)
7782

83+
class Config:
84+
"""Pydantic configuration."""
85+
86+
smart_union = True
87+
7888

7989
class RunMetadataResponseMetadata(WorkspaceScopedResponseMetadata):
8090
"""Response metadata for run metadata."""

tests/integration/functional/artifacts/test_utils.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,17 @@ def test_log_artifact_metadata_existing(clean_client):
124124
artifact_name="meaning_of_life",
125125
artifact_version="1",
126126
)
127+
log_artifact_metadata(
128+
{
129+
"float": 1.0,
130+
"int": 1,
131+
"str": "1.0",
132+
"list_str": ["1.0", "2.0"],
133+
"list_floats": [1.0, 2.0],
134+
},
135+
artifact_name="meaning_of_life",
136+
artifact_version="1",
137+
)
127138

128139
artifact_1 = clean_client.get_artifact_version(
129140
"meaning_of_life", version="1"
@@ -132,6 +143,23 @@ def test_log_artifact_metadata_existing(clean_client):
132143
assert artifact_1.run_metadata["description"].value == "Aria is great!"
133144
assert "description_3" in artifact_1.run_metadata
134145
assert artifact_1.run_metadata["description_3"].value == "Axl is great!"
146+
assert "float" in artifact_1.run_metadata
147+
assert artifact_1.run_metadata["float"].value - 1.0 < 10e-6
148+
assert "int" in artifact_1.run_metadata
149+
assert artifact_1.run_metadata["int"].value == 1
150+
assert "str" in artifact_1.run_metadata
151+
assert artifact_1.run_metadata["str"].value == "1.0"
152+
assert "list_str" in artifact_1.run_metadata
153+
assert (
154+
len(set(artifact_1.run_metadata["list_str"].value) - {"1.0", "2.0"})
155+
== 0
156+
)
157+
assert "list_floats" in artifact_1.run_metadata
158+
for each in artifact_1.run_metadata["list_floats"].value:
159+
if 0.99 < each < 1.01:
160+
assert each - 1.0 < 10e-6
161+
else:
162+
assert each - 2.0 < 10e-6
135163

136164
artifact_2 = clean_client.get_artifact_version(
137165
"meaning_of_life", version="43"

0 commit comments

Comments
 (0)