
Commit 75f91c3

Merge pull request #7 from tellmewyatt/main
Fixed JSON Schema Serialization for Components
2 parents 4b22762 + 96a27e1 commit 75f91c3

File tree

2 files changed: +62 −33 lines changed

First file (+19 −33):
@@ -1,12 +1,7 @@
-from typing import get_args, get_origin, List
+from pandas import DataFrame
+from pydantic import BaseModel, ConfigDict, create_model
 
-from pydantic import BaseModel, create_model, ConfigDict
-from haystack.dataclasses import Document
-
-
-class HaystackDocument(BaseModel):
-    id: str
-    content: str
+from hayhooks.server.utils.create_valid_type import handle_unsupported_types
 
 
 class PipelineDefinition(BaseModel):
@@ -29,13 +24,16 @@ def get_request_model(pipeline_name: str, pipeline_inputs):
     config = ConfigDict(arbitrary_types_allowed=True)
 
     for component_name, inputs in pipeline_inputs.items():
-
         component_model = {}
         for name, typedef in inputs.items():
-            component_model[name] = (typedef["type"], typedef.get("default_value", ...))
-        request_model[component_name] = (create_model('ComponentParams', **component_model, __config__=config), ...)
+            input_type = handle_unsupported_types(typedef["type"], {DataFrame: dict})
+            component_model[name] = (
+                input_type,
+                typedef.get("default_value", ...),
+            )
+        request_model[component_name] = (create_model("ComponentParams", **component_model, __config__=config), ...)
 
-    return create_model(f'{pipeline_name.capitalize()}RunRequest', **request_model, __config__=config)
+    return create_model(f"{pipeline_name.capitalize()}RunRequest", **request_model, __config__=config)
 
 
 def get_response_model(pipeline_name: str, pipeline_outputs):
@@ -49,44 +47,32 @@ def get_response_model(pipeline_name: str, pipeline_outputs):
     """
     response_model = {}
     config = ConfigDict(arbitrary_types_allowed=True)
-
     for component_name, outputs in pipeline_outputs.items():
         component_model = {}
         for name, typedef in outputs.items():
             output_type = typedef["type"]
-            if get_origin(output_type) == list and get_args(output_type)[0] == Document:
-                component_model[name] = (List[HaystackDocument], ...)
-            else:
-                component_model[name] = (typedef["type"], ...)
-        response_model[component_name] = (create_model('ComponentParams', **component_model, __config__=config), ...)
+            component_model[name] = (handle_unsupported_types(output_type, {DataFrame: dict}), ...)
+        response_model[component_name] = (create_model("ComponentParams", **component_model, __config__=config), ...)
 
-    return create_model(f'{pipeline_name.capitalize()}RunResponse', **response_model, __config__=config)
+    return create_model(f"{pipeline_name.capitalize()}RunResponse", **response_model, __config__=config)
 
 
 def convert_component_output(component_output):
     """
+    Converts outputs from a component as a dict so that it can be validated against response model
+
     Component output has this form:
 
     "documents":[
         {"id":"818170...", "content":"RapidAPI for Mac is a full-featured HTTP client."}
     ]
 
-    We inspect the output and convert haystack.Document into the HaystackDocument pydantic model as needed
     """
     result = {}
    for output_name, data in component_output.items():
-        # Empty containers, None values, empty strings and the likes: do nothing
-        if not data:
-            result[output_name] = data
-
-        # Output contains a list of Document
-        if type(data) is list and type(data[0]) is Document:
-            result[output_name] = [HaystackDocument(id=d.id, content=d.content) for d in data]
-        # Output is a single Document
-        elif type(data) is Document:
-            result[output_name] = HaystackDocument(id=data.id, content=data.content or "")
-        # Any other type: do nothing
+        get_value = lambda data: data.to_dict()["init_parameters"] if hasattr(data, "to_dict") else data
+        if type(data) is list:
+            result[output_name] = [get_value(d) for d in data]
         else:
-            result[output_name] = data
-
+            result[output_name] = get_value(data)
     return result
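
Note on the change above: instead of special-casing lists of haystack Document via the removed HaystackDocument model, any Pydantic-unsupported type in a component's input or output signature is now rewritten through handle_unsupported_types before create_model is called. A minimal sketch of the request-model path, assuming Python 3.9+, Pydantic v2, pandas and hayhooks importable, and a hypothetical component input named "frames" typed as List[DataFrame] (this sketch is not part of the commit):

# Sketch only, not from the commit. "frames" is a made-up field name.
from typing import List

from pandas import DataFrame
from pydantic import ConfigDict, create_model

from hayhooks.server.utils.create_valid_type import handle_unsupported_types

config = ConfigDict(arbitrary_types_allowed=True)

# DataFrame appears as a generic argument, so the mapping rewrites List[DataFrame] -> list[dict]
input_type = handle_unsupported_types(List[DataFrame], {DataFrame: dict})

ComponentParams = create_model("ComponentParams", frames=(input_type, ...), __config__=config)

# With the raw List[DataFrame] annotation, JSON schema generation trips over the
# arbitrary type; with list[dict] it produces a plain, serializable JSON schema.
print(ComponentParams.model_json_schema())

The same mapping is applied on the response side, which is what makes the generated JSON schema for these models serializable, as the commit title says.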
Second file, new (+43 −0):
@@ -0,0 +1,43 @@
+from inspect import isclass
+from types import GenericAlias
+from typing import Dict, Union, get_args, get_origin, get_type_hints
+
+from typing_extensions import TypedDict
+
+
+def handle_unsupported_types(type_: type, types_mapping: Dict[type, type]) -> Union[GenericAlias, type]:
+    """
+    Recursively handle types that are not supported by Pydantic by replacing them with the given types mapping.
+
+    :param type_: Type to replace if not supported
+    :param types_mapping: Mapping of types to replace
+    """
+
+    def _handle_generics(t_) -> GenericAlias:
+        """
+        Handle generics recursively
+        """
+        child_typing = []
+        for t in get_args(t_):
+            if t in types_mapping:
+                result = types_mapping[t]
+            elif isclass(t):
+                result = handle_unsupported_types(t, types_mapping)
+            else:
+                result = t
+            child_typing.append(result)
+        return GenericAlias(get_origin(t_), tuple(child_typing))
+
+    if isclass(type_):
+        new_type = {}
+        for arg_name, arg_type in get_type_hints(type_).items():
+            if get_args(arg_type):
+                new_type[arg_name] = _handle_generics(arg_type)
+            else:
+                new_type[arg_name] = arg_type
+        if new_type:
+            return TypedDict(type_.__name__, new_type)
+
+        return type_
+
+    return _handle_generics(type_)
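
For reference, a small usage sketch of the new helper (not part of the commit; the Report class below is hypothetical). It exercises both branches: a generic alias has its arguments rewritten through the mapping, while a plain class with type hints is rebuilt as a TypedDict with the same rewriting applied to its fields:

# Sketch only; "Report" is a made-up annotated class for illustration.
from typing import List

from pandas import DataFrame

from hayhooks.server.utils.create_valid_type import handle_unsupported_types


class Report:
    title: str
    tables: List[DataFrame]


# Generic branch: DataFrame is swapped for dict inside the alias.
print(handle_unsupported_types(List[DataFrame], {DataFrame: dict}))  # list[dict]

# Class branch: returns a TypedDict named "Report" with rewritten field types.
ReportDict = handle_unsupported_types(Report, {DataFrame: dict})
print(ReportDict.__annotations__)  # {'title': <class 'str'>, 'tables': list[dict]}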
