5
5
import importlib
6
6
import inspect
7
7
import sys
8
- from typing import Any , get_origin
8
+ import typing
9
+ from typing import Any , get_args , get_origin
9
10
10
11
from haystack import DeserializationError
11
12
@@ -32,19 +33,23 @@ def serialize_type(target: Any) -> str:
32
33
# Determine if the target is a type or an instance of a typing object
33
34
is_type_or_typing = isinstance (target , type ) or bool (get_origin (target ))
34
35
type_obj = target if is_type_or_typing else type (target )
35
- module = inspect .getmodule (type_obj )
36
36
type_obj_repr = repr (type_obj )
37
37
38
38
if type_obj_repr .startswith ("typing." ):
39
39
# e.g., typing.List[int] -> List[int], we'll add the module below
40
40
type_name = type_obj_repr .split ("." , 1 )[1 ]
41
+ elif origin := get_origin (type_obj ): # get the origin (base type of the parameterized generic type)
42
+ # get the arguments of the generic type
43
+ args = get_args (type_obj )
44
+ args_repr = ", " .join (serialize_type (arg ) for arg in args )
45
+ type_name = f"{ origin .__name__ } [{ args_repr } ]"
41
46
elif hasattr (type_obj , "__name__" ):
42
47
type_name = type_obj .__name__
43
48
else :
44
49
# If type cannot be serialized, raise an error
45
50
raise ValueError (f"Could not serialize type: { type_obj_repr } " )
46
51
47
- # Construct the full path with module name if available
52
+ module = inspect . getmodule ( type_obj )
48
53
if module and hasattr (module , "__name__" ):
49
54
if module .__name__ == "builtins" :
50
55
# omit the module name for builtins, it just clutters the output
@@ -73,6 +78,14 @@ def deserialize_type(type_str: str) -> Any:
73
78
If the type cannot be deserialized due to missing module or type.
74
79
"""
75
80
81
+ type_mapping = {
82
+ list : typing .List ,
83
+ dict : typing .Dict ,
84
+ set : typing .Set ,
85
+ tuple : typing .Tuple ,
86
+ frozenset : typing .FrozenSet ,
87
+ }
88
+
76
89
def parse_generic_args (args_str ):
77
90
args = []
78
91
bracket_count = 0
@@ -104,7 +117,10 @@ def parse_generic_args(args_str):
104
117
generic_args = tuple (deserialize_type (arg ) for arg in parse_generic_args (generics_str ))
105
118
106
119
# Reconstruct
107
- return main_type [generic_args ]
120
+ if sys .version_info >= (3 , 9 ) or repr (main_type ).startswith ("typing." ):
121
+ return main_type [generic_args ]
122
+ else :
123
+ return type_mapping [main_type ][generic_args ] # type: ignore
108
124
109
125
else :
110
126
# Handle non-generics
0 commit comments