12
12
# permissions and limitations under the License.
13
13
"""Component model."""
14
14
15
- from typing import Dict , Optional
15
+ from typing import Any , Dict , Optional
16
16
17
17
from pydantic import BaseModel , validator
18
18
19
- from mlstacks .constants import INVALID_NAME_ERROR_MESSAGE
19
+ from mlstacks .constants import (
20
+ INVALID_COMPONENT_FLAVOR_ERROR_MESSAGE ,
21
+ INVALID_COMPONENT_TYPE_ERROR_MESSAGE ,
22
+ INVALID_NAME_ERROR_MESSAGE ,
23
+ )
20
24
from mlstacks .enums import (
21
25
ComponentFlavorEnum ,
26
+ ComponentSpecVersionEnum ,
22
27
ComponentTypeEnum ,
23
28
ProviderEnum ,
29
+ SpecTypeEnum ,
30
+ )
31
+ from mlstacks .utils .model_utils import (
32
+ is_valid_component_flavor ,
33
+ is_valid_component_type ,
34
+ is_valid_name ,
24
35
)
25
- from mlstacks .utils .model_utils import is_valid_name
26
36
27
37
28
38
class ComponentMetadata (BaseModel ):
@@ -49,16 +59,16 @@ class Component(BaseModel):
49
59
metadata: The metadata of the component.
50
60
"""
51
61
52
- spec_version : int = 1
53
- spec_type : str = "component"
62
+ spec_version : ComponentSpecVersionEnum = ComponentSpecVersionEnum . ONE
63
+ spec_type : SpecTypeEnum = SpecTypeEnum . COMPONENT
54
64
name : str
65
+ provider : ProviderEnum
55
66
component_type : ComponentTypeEnum
56
67
component_flavor : ComponentFlavorEnum
57
- provider : ProviderEnum
58
68
metadata : Optional [ComponentMetadata ] = None
59
69
60
70
@validator ("name" )
61
- def validate_name (cls , name : str ) -> str : # noqa: N805
71
+ def validate_name (cls , name : str ) -> str : # noqa
62
72
"""Validate the name.
63
73
64
74
Name must start with an alphanumeric character and can only contain
@@ -78,3 +88,56 @@ def validate_name(cls, name: str) -> str: # noqa: N805
78
88
if not is_valid_name (name ):
79
89
raise ValueError (INVALID_NAME_ERROR_MESSAGE )
80
90
return name
91
+
92
+ @validator ("component_type" )
93
+ def validate_component_type (
94
+ cls , # noqa
95
+ component_type : str ,
96
+ values : Dict [str , Any ],
97
+ ) -> str :
98
+ """Validate the component type.
99
+
100
+ Artifact Store, Container Registry, Experiment Tracker, Orchestrator,
101
+ MLOps Platform, and Model Deployer may be used with aws, gcp, and k3d
102
+ providers. Step Operator may only be used with aws and gcp.
103
+
104
+ Args:
105
+ component_type: The component type.
106
+ values: The previously validated component specs.
107
+
108
+ Returns:
109
+ The validated component type.
110
+
111
+ Raises:
112
+ ValueError: If the component type is invalid.
113
+ """
114
+ if not is_valid_component_type (component_type , values ["provider" ]):
115
+ raise ValueError (INVALID_COMPONENT_TYPE_ERROR_MESSAGE )
116
+ return component_type
117
+
118
+ @validator ("component_flavor" )
119
+ def validate_component_flavor (
120
+ cls , # noqa
121
+ component_flavor : str ,
122
+ values : Dict [str , Any ],
123
+ ) -> str :
124
+ """Validate the component flavor.
125
+
126
+ Only certain flavors are allowed for a given provider-component
127
+ type combination. For more information, consult the tables for
128
+ your specified provider at the MLStacks documentation:
129
+ https://mlstacks.zenml.io/stacks/stack-specification.
130
+
131
+ Args:
132
+ component_flavor: The component flavor.
133
+ values: The previously validated component specs.
134
+
135
+ Returns:
136
+ The validated component flavor.
137
+
138
+ Raises:
139
+ ValueError: If the component flavor is invalid.
140
+ """
141
+ if not is_valid_component_flavor (component_flavor , values ):
142
+ raise ValueError (INVALID_COMPONENT_FLAVOR_ERROR_MESSAGE )
143
+ return component_flavor
0 commit comments