Skip to content

Commit 58e276a

Browse files
committed
validate array dimension
1 parent c045390 commit 58e276a

File tree

1 file changed

+32
-2
lines changed

1 file changed

+32
-2
lines changed

pyobvector/schema/array.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,12 @@ def __init__(self, item_type: Union[TypeEngine, type]):
2222
if isinstance(item_type, type):
2323
item_type = item_type()
2424
self.item_type = item_type
25+
if isinstance(item_type, ARRAY):
26+
self.dim = item_type.dim + 1
27+
else:
28+
self.dim = 1
29+
if self.dim > 6:
30+
raise ValueError("Maximum nesting level of 6 exceeded")
2531

2632
def get_col_spec(self, **kw): # pylint: disable=unused-argument
2733
"""Parse to array data type definition in text SQL."""
@@ -31,6 +37,20 @@ def get_col_spec(self, **kw): # pylint: disable=unused-argument
3137
base_type = str(self.item_type)
3238
return f"ARRAY({base_type})"
3339

40+
def _get_list_depth(self, value: Any) -> int:
41+
if not isinstance(value, list):
42+
return 0
43+
max_depth = 0
44+
for element in value:
45+
current_depth = self._get_list_depth(element)
46+
if current_depth > max_depth:
47+
max_depth = current_depth
48+
return 1 + max_depth
49+
50+
def _validate_dimension(self, value: list[Any]):
51+
arr_depth = self._get_list_depth(value)
52+
assert arr_depth == self.dim, "Array dimension mismatch, expected {}, got {}".format(self.dim, arr_depth)
53+
3454
def bind_processor(self, dialect):
3555
item_type = self.item_type
3656
while isinstance(item_type, ARRAY):
@@ -39,7 +59,10 @@ def bind_processor(self, dialect):
3959
item_proc = item_type.dialect_impl(dialect).bind_processor(dialect)
4060

4161
def process(value: Optional[Sequence[Any] | str]) -> Optional[str]:
42-
if value is None or isinstance(value, str):
62+
if value is None:
63+
return None
64+
if isinstance(value, str):
65+
self._validate_dimension(json.loads(value))
4366
return value
4467

4568
def convert(val):
@@ -50,6 +73,7 @@ def convert(val):
5073
return val
5174

5275
processed = convert(value)
76+
self._validate_dimension(processed)
5377
return json.dumps(processed)
5478

5579
return process
@@ -123,6 +147,12 @@ def __init__(self, item_type: Union[TypeEngine, type]):
123147
item_type = item_type()
124148

125149
assert not isinstance(item_type, ARRAY), "The item_type of NestedArray should not be an ARRAY type"
126-
self.item_type = item_type
150+
151+
nested_type = item_type
152+
for _ in range(dim):
153+
nested_type = ARRAY(nested_type)
154+
155+
self.item_type = nested_type.item_type
156+
self.dim = dim
127157

128158
return NestedArray

0 commit comments

Comments
 (0)