Skip to content

Commit c8d53b3

Browse files
authored
fix: Adjust serialization to handle PEP-585 generic types (#7690)
* Adjust serialization to handle PEP-585 generic types * Add reno note * Simplify * PEP 585 serialization handling in sys.version_info < (3, 9)
1 parent 96b9d3e commit c8d53b3

File tree

3 files changed

+51
-4
lines changed

3 files changed

+51
-4
lines changed

haystack/utils/type_serialization.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
import importlib
66
import inspect
77
import sys
8-
from typing import Any, get_origin
8+
import typing
9+
from typing import Any, get_args, get_origin
910

1011
from haystack import DeserializationError
1112

@@ -32,19 +33,23 @@ def serialize_type(target: Any) -> str:
3233
# Determine if the target is a type or an instance of a typing object
3334
is_type_or_typing = isinstance(target, type) or bool(get_origin(target))
3435
type_obj = target if is_type_or_typing else type(target)
35-
module = inspect.getmodule(type_obj)
3636
type_obj_repr = repr(type_obj)
3737

3838
if type_obj_repr.startswith("typing."):
3939
# e.g., typing.List[int] -> List[int], we'll add the module below
4040
type_name = type_obj_repr.split(".", 1)[1]
41+
elif origin := get_origin(type_obj): # get the origin (base type of the parameterized generic type)
42+
# get the arguments of the generic type
43+
args = get_args(type_obj)
44+
args_repr = ", ".join(serialize_type(arg) for arg in args)
45+
type_name = f"{origin.__name__}[{args_repr}]"
4146
elif hasattr(type_obj, "__name__"):
4247
type_name = type_obj.__name__
4348
else:
4449
# If type cannot be serialized, raise an error
4550
raise ValueError(f"Could not serialize type: {type_obj_repr}")
4651

47-
# Construct the full path with module name if available
52+
module = inspect.getmodule(type_obj)
4853
if module and hasattr(module, "__name__"):
4954
if module.__name__ == "builtins":
5055
# omit the module name for builtins, it just clutters the output
@@ -73,6 +78,14 @@ def deserialize_type(type_str: str) -> Any:
7378
If the type cannot be deserialized due to missing module or type.
7479
"""
7580

81+
type_mapping = {
82+
list: typing.List,
83+
dict: typing.Dict,
84+
set: typing.Set,
85+
tuple: typing.Tuple,
86+
frozenset: typing.FrozenSet,
87+
}
88+
7689
def parse_generic_args(args_str):
7790
args = []
7891
bracket_count = 0
@@ -104,7 +117,10 @@ def parse_generic_args(args_str):
104117
generic_args = tuple(deserialize_type(arg) for arg in parse_generic_args(generics_str))
105118

106119
# Reconstruct
107-
return main_type[generic_args]
120+
if sys.version_info >= (3, 9) or repr(main_type).startswith("typing."):
121+
return main_type[generic_args]
122+
else:
123+
return type_mapping[main_type][generic_args] # type: ignore
108124

109125
else:
110126
# Handle non-generics
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
fixes:
3+
- |
4+
Improves/fixes type serialization of PEP 585 types (e.g. list[Document], and their nested version). This improvement enables better serialization of generics and nested types and improves/fixes matching of list[X] and List[X] types in component connections after serialization.

test/utils/test_type_serialization.py

+27
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44
import copy
5+
import sys
56
import typing
67
from typing import List, Dict
78

9+
import pytest
10+
811
from haystack.dataclasses import ChatMessage
912
from haystack.components.routers.conditional_router import serialize_type, deserialize_type
1013

@@ -24,6 +27,18 @@ def test_output_type_serialization():
2427
assert serialize_type(ChatMessage.from_user("ciao")) == "haystack.dataclasses.chat_message.ChatMessage"
2528

2629

30+
@pytest.mark.skipif(sys.version_info < (3, 9), reason="PEP 585 types are only available in Python 3.9+")
31+
def test_output_type_serialization_pep585():
32+
# Only Python 3.9+ supports PEP 585 types and can serialize them
33+
# PEP 585 types
34+
assert serialize_type(list[int]) == "list[int]"
35+
assert serialize_type(list[list[int]]) == "list[list[int]]"
36+
37+
# more nested types
38+
assert serialize_type(list[list[list[int]]]) == "list[list[list[int]]]"
39+
assert serialize_type(dict[str, int]) == "dict[str, int]"
40+
41+
2742
def test_output_type_deserialization():
2843
assert deserialize_type("str") == str
2944
assert deserialize_type("typing.List[int]") == typing.List[int]
@@ -38,3 +53,15 @@ def test_output_type_deserialization():
3853
)
3954
assert deserialize_type("haystack.dataclasses.chat_message.ChatMessage") == ChatMessage
4055
assert deserialize_type("int") == int
56+
57+
58+
def test_output_type_deserialization_pep585():
59+
is_pep585 = sys.version_info >= (3, 9)
60+
61+
# Although only Python 3.9+ supports PEP 585 types, we can still deserialize them in older Python versions
62+
# as their typing equivalents
63+
assert deserialize_type("list[int]") == list[int] if is_pep585 else List[int]
64+
assert deserialize_type("dict[str, int]") == dict[str, int] if is_pep585 else Dict[str, int]
65+
# more nested types
66+
assert deserialize_type("list[list[int]]") == list[list[int]] if is_pep585 else List[List[int]]
67+
assert deserialize_type("list[list[list[int]]]") == list[list[list[int]]] if is_pep585 else List[List[List[int]]]

0 commit comments

Comments
 (0)