55import importlib
66import inspect
77import sys
8- from typing import Any , get_origin
8+ import typing
9+ from typing import Any , get_args , get_origin
910
1011from haystack import DeserializationError
1112
@@ -32,19 +33,23 @@ def serialize_type(target: Any) -> str:
3233 # Determine if the target is a type or an instance of a typing object
3334 is_type_or_typing = isinstance (target , type ) or bool (get_origin (target ))
3435 type_obj = target if is_type_or_typing else type (target )
35- module = inspect .getmodule (type_obj )
3636 type_obj_repr = repr (type_obj )
3737
3838 if type_obj_repr .startswith ("typing." ):
3939 # e.g., typing.List[int] -> List[int], we'll add the module below
4040 type_name = type_obj_repr .split ("." , 1 )[1 ]
41+ elif origin := get_origin (type_obj ): # get the origin (base type of the parameterized generic type)
42+ # get the arguments of the generic type
43+ args = get_args (type_obj )
44+ args_repr = ", " .join (serialize_type (arg ) for arg in args )
45+ type_name = f"{ origin .__name__ } [{ args_repr } ]"
4146 elif hasattr (type_obj , "__name__" ):
4247 type_name = type_obj .__name__
4348 else :
4449 # If type cannot be serialized, raise an error
4550 raise ValueError (f"Could not serialize type: { type_obj_repr } " )
4651
47- # Construct the full path with module name if available
52+ module = inspect . getmodule ( type_obj )
4853 if module and hasattr (module , "__name__" ):
4954 if module .__name__ == "builtins" :
5055 # omit the module name for builtins, it just clutters the output
@@ -73,6 +78,14 @@ def deserialize_type(type_str: str) -> Any:
7378 If the type cannot be deserialized due to missing module or type.
7479 """
7580
81+ type_mapping = {
82+ list : typing .List ,
83+ dict : typing .Dict ,
84+ set : typing .Set ,
85+ tuple : typing .Tuple ,
86+ frozenset : typing .FrozenSet ,
87+ }
88+
7689 def parse_generic_args (args_str ):
7790 args = []
7891 bracket_count = 0
@@ -104,7 +117,10 @@ def parse_generic_args(args_str):
104117 generic_args = tuple (deserialize_type (arg ) for arg in parse_generic_args (generics_str ))
105118
106119 # Reconstruct
107- return main_type [generic_args ]
120+ if sys .version_info >= (3 , 9 ) or repr (main_type ).startswith ("typing." ):
121+ return main_type [generic_args ]
122+ else :
123+ return type_mapping [main_type ][generic_args ] # type: ignore
108124
109125 else :
110126 # Handle non-generics
0 commit comments