Skip to content

Commit 96a27e1

Browse files
committed
Add some docs, enhance naming and handle Pipeline inputs too
1 parent 49b98b2 commit 96a27e1

File tree

2 files changed

+50
-35
lines changed

2 files changed

+50
-35
lines changed

src/hayhooks/server/pipelines/models.py

+13-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from hayhooks.server.utils.create_valid_type import create_valid_type
21
from pandas import DataFrame
3-
from pydantic import BaseModel, create_model, ConfigDict
2+
from pydantic import BaseModel, ConfigDict, create_model
3+
4+
from hayhooks.server.utils.create_valid_type import handle_unsupported_types
45

56

67
class PipelineDefinition(BaseModel):
@@ -23,13 +24,16 @@ def get_request_model(pipeline_name: str, pipeline_inputs):
2324
config = ConfigDict(arbitrary_types_allowed=True)
2425

2526
for component_name, inputs in pipeline_inputs.items():
26-
2727
component_model = {}
2828
for name, typedef in inputs.items():
29-
component_model[name] = (typedef["type"], typedef.get("default_value", ...))
30-
request_model[component_name] = (create_model('ComponentParams', **component_model, __config__=config), ...)
29+
input_type = handle_unsupported_types(typedef["type"], {DataFrame: dict})
30+
component_model[name] = (
31+
input_type,
32+
typedef.get("default_value", ...),
33+
)
34+
request_model[component_name] = (create_model("ComponentParams", **component_model, __config__=config), ...)
3135

32-
return create_model(f'{pipeline_name.capitalize()}RunRequest', **request_model, __config__=config)
36+
return create_model(f"{pipeline_name.capitalize()}RunRequest", **request_model, __config__=config)
3337

3438

3539
def get_response_model(pipeline_name: str, pipeline_outputs):
@@ -47,10 +51,10 @@ def get_response_model(pipeline_name: str, pipeline_outputs):
4751
component_model = {}
4852
for name, typedef in outputs.items():
4953
output_type = typedef["type"]
50-
component_model[name] = (create_valid_type(output_type, { DataFrame: dict}), ...)
51-
response_model[component_name] = (create_model('ComponentParams', **component_model, __config__=config), ...)
54+
component_model[name] = (handle_unsupported_types(output_type, {DataFrame: dict}), ...)
55+
response_model[component_name] = (create_model("ComponentParams", **component_model, __config__=config), ...)
5256

53-
return create_model(f'{pipeline_name.capitalize()}RunResponse', **response_model, __config__=config)
57+
return create_model(f"{pipeline_name.capitalize()}RunResponse", **response_model, __config__=config)
5458

5559

5660
def convert_component_output(component_output):
+37-26
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,43 @@
1-
from typing import get_type_hints, Dict, get_origin, get_args
2-
from typing_extensions import TypedDict
3-
from types import GenericAlias
41
from inspect import isclass
2+
from types import GenericAlias
3+
from typing import Dict, Union, get_args, get_origin, get_type_hints
4+
5+
from typing_extensions import TypedDict
6+
57

6-
def create_valid_type(typed_object:type, invalid_types:Dict[type, type]):
7-
"""
8-
Returns a new type specification, replacing invalid_types in typed_object.
9-
example: replace_invalid_types(ExtractedAnswer, {DataFrame: List}]) returns
10-
a TypedDict with DataFrame types replaced with List
8+
def handle_unsupported_types(type_: type, types_mapping: Dict[type, type]) -> Union[GenericAlias, type]:
119
"""
12-
def validate_type(v):
10+
Recursively handle types that are not supported by Pydantic by replacing them with the given types mapping.
11+
12+
:param type_: Type to replace if not supported
13+
:param types_mapping: Mapping of types to replace
14+
"""
15+
16+
def _handle_generics(t_) -> GenericAlias:
17+
"""
18+
Handle generics recursively
19+
"""
1320
child_typing = []
14-
for t in get_args(v):
15-
if t in invalid_types:
16-
result = invalid_types[t]
21+
for t in get_args(t_):
22+
if t in types_mapping:
23+
result = types_mapping[t]
1724
elif isclass(t):
18-
result = create_valid_type(t, invalid_types)
19-
else: result = t
25+
result = handle_unsupported_types(t, types_mapping)
26+
else:
27+
result = t
2028
child_typing.append(result)
21-
return GenericAlias(get_origin(v), tuple(child_typing))
22-
if isclass(typed_object):
23-
new_typing = {}
24-
for k, v in get_type_hints(typed_object).items():
25-
if(get_args(v) != ()):
26-
new_typing[k] = validate_type(v)
27-
else: new_typing[k] = v
28-
if new_typing == {}:
29-
return typed_object
30-
else: return TypedDict(typed_object.__name__, new_typing)
31-
else:
32-
return validate_type(typed_object)
29+
return GenericAlias(get_origin(t_), tuple(child_typing))
30+
31+
if isclass(type_):
32+
new_type = {}
33+
for arg_name, arg_type in get_type_hints(type_).items():
34+
if get_args(arg_type):
35+
new_type[arg_name] = _handle_generics(arg_type)
36+
else:
37+
new_type[arg_name] = arg_type
38+
if new_type:
39+
return TypedDict(type_.__name__, new_type)
40+
41+
return type_
42+
43+
return _handle_generics(type_)

0 commit comments

Comments
 (0)