Skip to content

Commit 3a23052

Browse files
committed
convert union type with none to optional
1 parent 0596208 commit 3a23052

File tree

2 files changed

+16
-3
lines changed

2 files changed

+16
-3
lines changed

src/hayhooks/server/pipelines/models.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@ def get_request_model(pipeline_name: str, pipeline_inputs):
2626
for component_name, inputs in pipeline_inputs.items():
2727
component_model = {}
2828
for name, typedef in inputs.items():
29-
input_type = handle_unsupported_types(typedef["type"], {DataFrame: dict})
29+
try:
30+
input_type = handle_unsupported_types(typedef["type"], {DataFrame: dict})
31+
except TypeError as e:
32+
print(f"ERROR at {component_name!r}, {name}: {typedef}")
33+
raise e
3034
component_model[name] = (
3135
input_type,
3236
typedef.get("default_value", ...),
@@ -70,7 +74,10 @@ def convert_component_output(component_output):
7074
"""
7175
result = {}
7276
for output_name, data in component_output.items():
73-
get_value = lambda data: data.to_dict()["init_parameters"] if hasattr(data, "to_dict") else data
77+
78+
def get_value(data):
79+
return data.to_dict()["init_parameters"] if hasattr(data, "to_dict") else data
80+
7481
if type(data) is list:
7582
result[output_name] = [get_value(d) for d in data]
7683
else:

src/hayhooks/server/utils/create_valid_type.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from inspect import isclass
22
from types import GenericAlias
3-
from typing import Dict, Union, get_args, get_origin, get_type_hints
3+
from typing import Dict, Union, Optional, get_args, get_origin, get_type_hints
44

55
from typing_extensions import TypedDict
66

@@ -36,6 +36,12 @@ def _handle_generics(t_) -> GenericAlias:
3636
else:
3737
new_type[arg_name] = arg_type
3838
if new_type:
39+
# because TypedDict can't handle union types with None
40+
# rewrite them as Optional[type]
41+
for arg_name, arg_type in new_type.items():
42+
type_args = get_args(arg_type)
43+
if len(type_args) == 2 and type_args[1] is type(None):
44+
new_type[arg_name] = Optional[type_args[0]]
3945
return TypedDict(type_.__name__, new_type)
4046

4147
return type_

0 commit comments

Comments
 (0)