Skip to content

Commit cc15953

Browse files
Enhance stack validation (#148)
* Added spec validations in Stack and Component pydantic models. Added check for mismatch between stack and component provider to yaml_utils.py * Working out testing changes * Changed tests and added test_utils * Made changes according to formatter and linter. * Made final changes for pull request. Got rid of comments and print calls. * Removed a comment in yaml_utils.py --------- Co-authored-by: Alex Strick van Linschoten <[email protected]>
1 parent c9f4ab9 commit cc15953

10 files changed

+319
-24
lines changed

src/mlstacks/constants.py

+62
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# permissions and limitations under the License.
1313
"""MLStacks constants."""
1414

15+
from typing import Dict, List
16+
1517
MLSTACKS_PACKAGE_NAME = "mlstacks"
1618
MLSTACKS_INITIALIZATION_FILE_FLAG = "IGNORE_ME"
1719
MLSTACKS_STACK_COMPONENT_FLAGS = [
@@ -39,6 +41,52 @@
3941
"model_deployer": ["seldon"],
4042
"step_operator": ["sagemaker", "vertex"],
4143
}
44+
# Lookup table used for component validation, nested as
# provider -> component type -> list of permitted flavors.
# A provider mapped to an empty dict (azure) lists no component types.
ALLOWED_COMPONENT_TYPES: Dict[str, Dict[str, List[str]]] = {
    "aws": {
        "artifact_store": ["s3"],
        "container_registry": ["aws"],
        "experiment_tracker": ["mlflow"],
        "orchestrator": [
            "kubeflow",
            "kubernetes",
            "sagemaker",
            "skypilot",
            "tekton",
        ],
        "mlops_platform": ["zenml"],
        "model_deployer": ["seldon"],
        "step_operator": ["sagemaker"],
    },
    "azure": {},
    "gcp": {
        "artifact_store": ["gcp"],
        "container_registry": ["gcp"],
        "experiment_tracker": ["mlflow"],
        "orchestrator": [
            "kubeflow",
            "kubernetes",
            "skypilot",
            "tekton",
            "vertex",
        ],
        "mlops_platform": ["zenml"],
        "model_deployer": ["seldon"],
        "step_operator": ["vertex"],
    },
    # k3d lists no step_operator entry.
    "k3d": {
        "artifact_store": ["minio"],
        "container_registry": ["default"],
        "experiment_tracker": ["mlflow"],
        "orchestrator": [
            "kubeflow",
            "kubernetes",
            "sagemaker",
            "tekton",
        ],
        "mlops_platform": ["zenml"],
        "model_deployer": ["seldon"],
    },
}
4290

4391
PERMITTED_NAME_REGEX = r"^[a-zA-Z0-9][a-zA-Z0-9_-]*$"
4492
ANALYTICS_OPT_IN_ENV_VARIABLE = "MLSTACKS_ANALYTICS_OPT_IN"
@@ -49,5 +97,19 @@
4997
"contain alphanumeric characters, underscores, and hyphens "
5098
"thereafter."
5199
)
100+
# Raised when a component type is not permitted for the component's provider.
INVALID_COMPONENT_TYPE_ERROR_MESSAGE = (
    "Artifact Store, Container Registry, Experiment Tracker, Orchestrator, "
    "MLOps Platform, and Model Deployer may be used with aws, gcp, and k3d "
    "providers. Step Operator may only be used with aws and gcp."
)
# Raised when a flavor is not permitted for the provider/component type pair.
INVALID_COMPONENT_FLAVOR_ERROR_MESSAGE = (
    "Only certain flavors are allowed for a given provider-component type "
    "combination. For more information, consult the tables for your specified "
    "provider at the MLStacks documentation: "
    "https://mlstacks.zenml.io/stacks/stack-specification."
)
# Raised when a stack's provider differs from one of its components' providers.
STACK_COMPONENT_PROVIDER_MISMATCH_ERROR_MESSAGE = (
    "Stack provider and component provider mismatch."
)
52114
DEFAULT_REMOTE_STATE_BUCKET_NAME = "zenml-mlstacks-remote-state"
53115
TERRAFORM_CONFIG_BUCKET_REPLACEMENT_STRING = "BUCKETNAMEREPLACEME"

src/mlstacks/enums.py

+20
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class ComponentFlavorEnum(str, Enum):
4949
TEKTON = "tekton"
5050
VERTEX = "vertex"
5151
ZENML = "zenml"
52+
DEFAULT = "default"
5253

5354

5455
class DeploymentMethodEnum(str, Enum):
@@ -77,3 +78,22 @@ class AnalyticsEventsEnum(str, Enum):
7778
MLSTACKS_SOURCE = "MLStacks Source"
7879
MLSTACKS_EXCEPTION = "MLStacks Exception"
7980
MLSTACKS_VERSION = "MLStacks Version"
81+
82+
83+
class SpecTypeEnum(str, Enum):
    """Spec type enum.

    Enumerates the kinds of spec: a stack or a component.
    """

    STACK = "stack"
    COMPONENT = "component"
88+
89+
90+
class StackSpecVersionEnum(int, Enum):
    """Stack spec version enum.

    Enumerates the stack spec versions; version 1 is the only member.
    """

    ONE = 1
94+
95+
96+
class ComponentSpecVersionEnum(int, Enum):
    """Component spec version enum.

    Enumerates the component spec versions; version 1 is the only member.
    """

    ONE = 1

src/mlstacks/models/component.py

+70-7
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,27 @@
1212
# permissions and limitations under the License.
1313
"""Component model."""
1414

15-
from typing import Dict, Optional
15+
from typing import Any, Dict, Optional
1616

1717
from pydantic import BaseModel, validator
1818

19-
from mlstacks.constants import INVALID_NAME_ERROR_MESSAGE
19+
from mlstacks.constants import (
20+
INVALID_COMPONENT_FLAVOR_ERROR_MESSAGE,
21+
INVALID_COMPONENT_TYPE_ERROR_MESSAGE,
22+
INVALID_NAME_ERROR_MESSAGE,
23+
)
2024
from mlstacks.enums import (
2125
ComponentFlavorEnum,
26+
ComponentSpecVersionEnum,
2227
ComponentTypeEnum,
2328
ProviderEnum,
29+
SpecTypeEnum,
30+
)
31+
from mlstacks.utils.model_utils import (
32+
is_valid_component_flavor,
33+
is_valid_component_type,
34+
is_valid_name,
2435
)
25-
from mlstacks.utils.model_utils import is_valid_name
2636

2737

2838
class ComponentMetadata(BaseModel):
@@ -49,16 +59,16 @@ class Component(BaseModel):
4959
metadata: The metadata of the component.
5060
"""
5161

52-
spec_version: int = 1
53-
spec_type: str = "component"
62+
spec_version: ComponentSpecVersionEnum = ComponentSpecVersionEnum.ONE
63+
spec_type: SpecTypeEnum = SpecTypeEnum.COMPONENT
5464
name: str
65+
provider: ProviderEnum
5566
component_type: ComponentTypeEnum
5667
component_flavor: ComponentFlavorEnum
57-
provider: ProviderEnum
5868
metadata: Optional[ComponentMetadata] = None
5969

6070
@validator("name")
61-
def validate_name(cls, name: str) -> str: # noqa: N805
71+
def validate_name(cls, name: str) -> str: # noqa
6272
"""Validate the name.
6373
6474
Name must start with an alphanumeric character and can only contain
@@ -78,3 +88,56 @@ def validate_name(cls, name: str) -> str: # noqa: N805
7888
if not is_valid_name(name):
7989
raise ValueError(INVALID_NAME_ERROR_MESSAGE)
8090
return name
91+
92+
@validator("component_type")
def validate_component_type(
    cls,  # noqa
    component_type: str,
    values: Dict[str, Any],
) -> str:
    """Validate the component type against the component's provider.

    Artifact Store, Container Registry, Experiment Tracker, Orchestrator,
    MLOps Platform, and Model Deployer may be used with aws, gcp, and k3d
    providers. Step Operator may only be used with aws and gcp.

    Args:
        component_type: The component type.
        values: The previously validated component specs.

    Returns:
        The validated component type.

    Raises:
        ValueError: If the component type is invalid for the provider.
    """
    # A field that failed its own validation is left out of `values`, so
    # `provider` may be absent here. Indexing `values["provider"]` would then
    # raise a KeyError, which is not reported as a validation error. The
    # provider's own error is already being reported, so skip this check.
    provider = values.get("provider")
    if provider is None:
        return component_type

    if not is_valid_component_type(component_type, provider):
        raise ValueError(INVALID_COMPONENT_TYPE_ERROR_MESSAGE)
    return component_type
117+
118+
@validator("component_flavor")
def validate_component_flavor(
    cls,  # noqa
    component_flavor: str,
    values: Dict[str, Any],
) -> str:
    """Check that the flavor is permitted for this component.

    The set of permitted flavors depends on the provider and the
    component type. The per-provider tables are published in the
    MLStacks documentation:
    https://mlstacks.zenml.io/stacks/stack-specification.

    Args:
        component_flavor: The component flavor.
        values: The previously validated component specs.

    Returns:
        The validated component flavor.

    Raises:
        ValueError: If the component flavor is invalid.
    """
    # Guard-clause form: accept first, fall through to the rejection.
    if is_valid_component_flavor(component_flavor, values):
        return component_flavor
    raise ValueError(INVALID_COMPONENT_FLAVOR_ERROR_MESSAGE)

src/mlstacks/models/stack.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from mlstacks.enums import (
2020
DeploymentMethodEnum,
2121
ProviderEnum,
22+
SpecTypeEnum,
23+
StackSpecVersionEnum,
2224
)
2325
from mlstacks.models.component import Component
2426
from mlstacks.utils.model_utils import is_valid_name
@@ -38,8 +40,8 @@ class Stack(BaseModel):
3840
components: The components of the stack.
3941
"""
4042

41-
spec_version: int = 1
42-
spec_type: str = "stack"
43+
spec_version: StackSpecVersionEnum = StackSpecVersionEnum.ONE
44+
spec_type: SpecTypeEnum = SpecTypeEnum.STACK
4345
name: str
4446
provider: ProviderEnum
4547
default_region: Optional[str]
@@ -50,7 +52,7 @@ class Stack(BaseModel):
5052
components: List[Component] = []
5153

5254
@validator("name")
53-
def validate_name(cls, name: str) -> str: # noqa: N805
55+
def validate_name(cls, name: str) -> str: # noqa
5456
"""Validate the name.
5557
5658
Name must start with an alphanumeric character and can only contain

src/mlstacks/utils/model_utils.py

+45-1
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@
1313
"""Util functions for Pydantic models and validation."""
1414

1515
import re
16+
from typing import Any, Dict
1617

17-
from mlstacks.constants import PERMITTED_NAME_REGEX
18+
from mlstacks.constants import ALLOWED_COMPONENT_TYPES, PERMITTED_NAME_REGEX
1819

1920

2021
def is_valid_name(name: str) -> bool:
@@ -29,3 +30,46 @@ def is_valid_name(name: str) -> bool:
2930
True if the name is valid, False otherwise.
3031
"""
3132
return re.match(PERMITTED_NAME_REGEX, name) is not None
33+
34+
35+
def is_valid_component_type(component_type: str, provider: str) -> bool:
    """Check if the component type is valid for the given provider.

    Used for components.

    Args:
        component_type: The component type.
        provider: The provider.

    Returns:
        True if the component type is listed for the provider, False
        otherwise (including when the provider itself is not listed).
    """
    # `.get` with an empty default makes an unlisted provider yield False
    # rather than a KeyError. Membership is tested on the dict directly,
    # so no copy of its keys is needed.
    return component_type in ALLOWED_COMPONENT_TYPES.get(provider, {})
49+
50+
51+
def is_valid_component_flavor(
    component_flavor: str, specs: Dict[str, Any]
) -> bool:
    """Tell whether a flavor is permitted for a provider/component type.

    Used for components.

    Args:
        component_flavor: The component flavor.
        specs: The previously validated component specs.

    Returns:
        True if the component flavor is valid, False otherwise.
    """
    try:
        provider_types = ALLOWED_COMPONENT_TYPES[specs["provider"]]
        permitted_flavors = provider_types[specs["component_type"]]
    except KeyError:
        # A missing spec key or an unlisted provider/component type
        # means no flavor can be valid.
        return False

    return component_flavor in permitted_flavors

src/mlstacks/utils/yaml_utils.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import yaml
1818

19+
from mlstacks.constants import STACK_COMPONENT_PROVIDER_MISMATCH_ERROR_MESSAGE
1920
from mlstacks.models.component import (
2021
Component,
2122
ComponentMetadata,
@@ -57,9 +58,17 @@ def load_component_yaml(path: str) -> Component:
5758
5859
Returns:
5960
The component model.
61+
62+
Raises:
63+
FileNotFoundError: If the file is not found.
6064
"""
61-
with open(path) as file:
62-
component_data = yaml.safe_load(file)
65+
try:
66+
with open(path) as file:
67+
component_data = yaml.safe_load(file)
68+
except FileNotFoundError as exc:
69+
error_message = f"""Component file at "{path}" specified in
70+
the stack spec file could not be found."""
71+
raise FileNotFoundError(error_message) from exc
6372

6473
if component_data.get("metadata") is None:
6574
component_data["metadata"] = {}
@@ -88,14 +97,18 @@ def load_stack_yaml(path: str) -> Stack:
8897
8998
Returns:
9099
The stack model.
100+
101+
Raises:
102+
ValueError: If the stack and component have different providers
91103
"""
92104
with open(path) as yaml_file:
93105
stack_data = yaml.safe_load(yaml_file)
94106
component_data = stack_data.get("components")
95107

96108
if component_data is None:
97109
component_data = []
98-
return Stack(
110+
111+
stack = Stack(
99112
spec_version=stack_data.get("spec_version"),
100113
spec_type=stack_data.get("spec_type"),
101114
name=stack_data.get("name"),
@@ -107,3 +120,9 @@ def load_stack_yaml(path: str) -> Stack:
107120
load_component_yaml(component) for component in component_data
108121
],
109122
)
123+
124+
for component in stack.components:
125+
if component.provider != stack.provider:
126+
raise ValueError(STACK_COMPONENT_PROVIDER_MISMATCH_ERROR_MESSAGE)
127+
128+
return stack

0 commit comments

Comments
 (0)