Skip to content

Commit 112ae25

Browse files
authored
Merge pull request #31 from Rusteam/fix-typing
convert union type with none to optional
2 parents e4b18bb + a7f07bf commit 112ae25

File tree

2 files changed

+17
-8
lines changed

2 files changed

+17
-8
lines changed

src/hayhooks/server/pipelines/models.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@ def get_request_model(pipeline_name: str, pipeline_inputs):
2626
for component_name, inputs in pipeline_inputs.items():
2727
component_model = {}
2828
for name, typedef in inputs.items():
29-
input_type = handle_unsupported_types(typedef["type"], {DataFrame: dict})
29+
try:
30+
input_type = handle_unsupported_types(typedef["type"], {DataFrame: dict})
31+
except TypeError as e:
32+
print(f"ERROR at {component_name!r}, {name}: {typedef}")
33+
raise e
3034
component_model[name] = (
3135
input_type,
3236
typedef.get("default_value", ...),
@@ -70,7 +74,10 @@ def convert_component_output(component_output):
7074
"""
7175
result = {}
7276
for output_name, data in component_output.items():
73-
get_value = lambda data: data.to_dict()["init_parameters"] if hasattr(data, "to_dict") else data
77+
78+
def get_value(data):
79+
return data.to_dict()["init_parameters"] if hasattr(data, "to_dict") else data
80+
7481
if type(data) is list:
7582
result[output_name] = [get_value(d) for d in data]
7683
else:

src/hayhooks/server/utils/create_valid_type.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
from inspect import isclass
22
from types import GenericAlias
3-
from typing import Dict, Union, get_args, get_origin, get_type_hints
4-
5-
from typing_extensions import TypedDict
3+
from typing import Dict, Optional, Union, get_args, get_origin, get_type_hints
64

75

86
def handle_unsupported_types(type_: type, types_mapping: Dict[type, type]) -> Union[GenericAlias, type]:
@@ -26,7 +24,13 @@ def _handle_generics(t_) -> GenericAlias:
2624
else:
2725
result = t
2826
child_typing.append(result)
29-
return GenericAlias(get_origin(t_), tuple(child_typing))
27+
28+
if len(child_typing) == 2 and child_typing[1] is type(None):
29+
# because TypedDict can't handle union types with None
30+
# rewrite them as Optional[type]
31+
return Optional[child_typing[0]]
32+
else:
33+
return GenericAlias(get_origin(t_), tuple(child_typing))
3034

3135
if isclass(type_):
3236
new_type = {}
@@ -35,8 +39,6 @@ def _handle_generics(t_) -> GenericAlias:
3539
new_type[arg_name] = _handle_generics(arg_type)
3640
else:
3741
new_type[arg_name] = arg_type
38-
if new_type:
39-
return TypedDict(type_.__name__, new_type)
4042

4143
return type_
4244

0 commit comments

Comments
 (0)