@@ -22,6 +22,12 @@ def __init__(self, item_type: Union[TypeEngine, type]):
2222 if isinstance (item_type , type ):
2323 item_type = item_type ()
2424 self .item_type = item_type
25+ if isinstance (item_type , ARRAY ):
26+ self .dim = item_type .dim + 1
27+ else :
28+ self .dim = 1
29+ if self .dim > 6 :
30+ raise ValueError ("Maximum nesting level of 6 exceeded" )
2531
2632 def get_col_spec (self , ** kw ): # pylint: disable=unused-argument
2733 """Parse to array data type definition in text SQL."""
@@ -31,6 +37,20 @@ def get_col_spec(self, **kw): # pylint: disable=unused-argument
3137 base_type = str (self .item_type )
3238 return f"ARRAY({ base_type } )"
3339
40+ def _get_list_depth (self , value : Any ) -> int :
41+ if not isinstance (value , list ):
42+ return 0
43+ max_depth = 0
44+ for element in value :
45+ current_depth = self ._get_list_depth (element )
46+ if current_depth > max_depth :
47+ max_depth = current_depth
48+ return 1 + max_depth
49+
50+ def _validate_dimension (self , value : list [Any ]):
51+ arr_depth = self ._get_list_depth (value )
52+ assert arr_depth == self .dim , "Array dimension mismatch, expected {}, got {}" .format (self .dim , arr_depth )
53+
3454 def bind_processor (self , dialect ):
3555 item_type = self .item_type
3656 while isinstance (item_type , ARRAY ):
@@ -39,7 +59,10 @@ def bind_processor(self, dialect):
3959 item_proc = item_type .dialect_impl (dialect ).bind_processor (dialect )
4060
4161 def process (value : Optional [Sequence [Any ] | str ]) -> Optional [str ]:
42- if value is None or isinstance (value , str ):
62+ if value is None :
63+ return None
64+ if isinstance (value , str ):
65+ self ._validate_dimension (json .loads (value ))
4366 return value
4467
4568 def convert (val ):
@@ -50,6 +73,7 @@ def convert(val):
5073 return val
5174
5275 processed = convert (value )
76+ self ._validate_dimension (processed )
5377 return json .dumps (processed )
5478
5579 return process
@@ -123,6 +147,12 @@ def __init__(self, item_type: Union[TypeEngine, type]):
123147 item_type = item_type ()
124148
125149 assert not isinstance (item_type , ARRAY ), "The item_type of NestedArray should not be an ARRAY type"
126- self .item_type = item_type
150+
151+ nested_type = item_type
152+ for _ in range (dim ):
153+ nested_type = ARRAY (nested_type )
154+
155+ self .item_type = nested_type .item_type
156+ self .dim = dim
127157
128158 return NestedArray
0 commit comments