Skip to content

Commit bae5b35

Browse files
committed
fix: change to arrow scalar types to accommodate more data types
1 parent 9cb7d79 commit bae5b35

File tree

5 files changed

+251
-92
lines changed

5 files changed

+251
-92
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ authors = [
77
]
88
dependencies = [
99
"sqlglot>=26.33.0",
10+
"duckdb>=1.3.2",
11+
"pyarrow>=20.0.0",
1012
]
1113
readme = "README.md"
1214
requires-python = ">= 3.11"

requirements-dev.lock

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
-e file:.
1313
coverage==7.9.2
1414
# via pytest-cov
15+
duckdb==1.3.2
16+
# via query-farm-sql-scan-planning
1517
filelock==3.18.0
1618
# via pytest-mypy
1719
iniconfig==2.1.0
@@ -27,6 +29,8 @@ pathspec==0.12.1
2729
pluggy==1.6.0
2830
# via pytest
2931
# via pytest-cov
32+
pyarrow==20.0.0
33+
# via query-farm-sql-scan-planning
3034
pygments==2.19.2
3135
# via pytest
3236
pytest==8.4.1
@@ -37,7 +41,7 @@ pytest-cov==6.2.1
3741
pytest-env==1.1.5
3842
pytest-mypy==1.0.1
3943
ruff==0.12.2
40-
sqlglot==26.33.0
44+
sqlglot==27.0.0
4145
# via query-farm-sql-scan-planning
4246
typing-extensions==4.14.1
4347
# via mypy

requirements.lock

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,9 @@
1010
# universal: false
1111

1212
-e file:.
13-
sqlglot==26.33.0
13+
duckdb==1.3.2
14+
# via query-farm-sql-scan-planning
15+
pyarrow==20.0.0
16+
# via query-farm-sql-scan-planning
17+
sqlglot==27.0.0
1418
# via query-farm-sql-scan-planning

src/query_farm_sql_scan_planning/planner.py

Lines changed: 139 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from collections.abc import Callable
22
from dataclasses import dataclass
3-
from decimal import Decimal
43
from typing import Any
5-
4+
import duckdb
5+
import pyarrow as pa
66
import sqlglot
77
import sqlglot.expressions
88
import sqlglot.optimizer.simplify
@@ -19,36 +19,82 @@ class BaseFieldInfo:
1919

2020

2121
@dataclass
22-
class RangeFieldInfo[T: Any](BaseFieldInfo):
22+
class RangeFieldInfo(BaseFieldInfo):
2323
"""
2424
Information about a field that has a min and max value.
2525
"""
2626

27-
min_value: T
28-
max_value: T
27+
min_value: pa.Scalar
28+
max_value: pa.Scalar
2929

3030

3131
@dataclass
32-
class SetFieldInfo[T: Any](BaseFieldInfo):
32+
class SetFieldInfo(BaseFieldInfo):
3333
"""
3434
Information about a field where the set of values are known.
3535
The information about what values that are contained can produce
3636
false positives.
3737
"""
3838

3939
values: set[
40-
T
40+
pa.Scalar
4141
] # Set of values that are known to be present in the field, false positives are okay.
4242

4343

44-
AnyFieldInfo = (
45-
SetFieldInfo[Decimal]
46-
| SetFieldInfo[float]
47-
| SetFieldInfo[str]
48-
| SetFieldInfo[int]
49-
| RangeFieldInfo[int]
50-
| RangeFieldInfo[None]
51-
)
44+
AnyFieldInfo = SetFieldInfo | RangeFieldInfo
45+
46+
47+
def _scalar_value_op(
48+
a: pa.Scalar, b: pa.Scalar, op: Callable[[Any, Any], bool]
49+
) -> bool:
50+
assert not pa.types.is_null(a.type), (
51+
f"Expected a non-null scalar value, got {a} of type {a.type}"
52+
)
53+
assert not pa.types.is_null(b.type), (
54+
f"Expected a non-null scalar value, got {b} of type {b.type}"
55+
)
56+
57+
# If we have integers or floats we can do that comparision regardless of their types.
58+
if pa.types.is_integer(a.type) and pa.types.is_integer(b.type):
59+
return op(a.as_py(), b.as_py())
60+
61+
if pa.types.is_floating(a.type) and pa.types.is_floating(b.type):
62+
return op(a.as_py(), b.as_py())
63+
64+
if pa.types.is_string(a.type) and pa.types.is_string(b.type):
65+
return op(a.as_py(), b.as_py())
66+
67+
if pa.types.is_boolean(a.type) and pa.types.is_boolean(b.type):
68+
return op(a.as_py(), b.as_py())
69+
70+
if pa.types.is_decimal(a.type) and pa.types.is_decimal(b.type):
71+
return op(a.as_py(), b.as_py())
72+
73+
assert type(a) is type(b), (
74+
f"Expected same type for comparison, got {type(a)} and {type(b)}"
75+
)
76+
77+
return op(a.as_py(), b.as_py())
78+
79+
80+
def _scalar_value_lte(a: pa.Scalar, b: pa.Scalar) -> bool:
81+
return _scalar_value_op(a, b, lambda x, y: x <= y)
82+
83+
84+
def _scalar_value_lt(a: pa.Scalar, b: pa.Scalar) -> bool:
85+
return _scalar_value_op(a, b, lambda x, y: x < y)
86+
87+
88+
def _scalar_value_gt(a: pa.Scalar, b: pa.Scalar) -> bool:
89+
return _scalar_value_op(a, b, lambda x, y: x > y)
90+
91+
92+
def _scalar_value_gte(a: pa.Scalar, b: pa.Scalar) -> bool:
93+
return _scalar_value_op(a, b, lambda x, y: x >= y)
94+
95+
96+
def _scalar_value_eq(a: pa.Scalar, b: pa.Scalar) -> bool:
97+
return _scalar_value_op(a, b, lambda x, y: x == y)
5298

5399

54100
FileFieldInfo = dict[str, AnyFieldInfo]
@@ -93,18 +139,29 @@ def _eval_predicate(
93139
if not isinstance(node.left, sqlglot.expressions.Column):
94140
return None
95141

142+
if node.right.find(sqlglot.expressions.Column) is not None:
143+
# Can't evaluate this since it has a right hand column ref, ideally
144+
# this should be removed further up.
145+
return None
146+
96147
# The thing on the right side should be something that can be evaluated against a range.
97148
# ideally, its going to be a
98-
assert isinstance(
99-
node.right,
100-
sqlglot.expressions.Literal
101-
| sqlglot.expressions.Null
102-
| sqlglot.expressions.Neg,
103-
), (
104-
f"Expected a literal or null on righthand side of predicate {node} got a {type(node.right)}"
105-
)
149+
if True: # isinstance(node.right, sqlglot.expressions.Cast):
150+
connection = duckdb.connect(":memory:")
151+
value_result = connection.execute(
152+
f"select {node.right.sql('duckdb')}"
153+
).arrow()
154+
assert value_result.num_rows == 1, (
155+
f"Expected a single row result from cast, got {value_result.num_rows} rows"
156+
)
157+
assert value_result.num_columns == 1, (
158+
f"Expected a single column result from cast, got {value_result.num_columns} columns"
159+
)
106160

107-
right_val = node.right.to_py()
161+
right_val = value_result.column(0)[0]
162+
# This is an interesting behavior, null is returned with an int32 type.
163+
if type(right_val) is pa.Int32Scalar and right_val.as_py() is None:
164+
right_val = pa.scalar(None, type=pa.null())
108165

109166
left_val = node.left
110167
assert isinstance(left_val, sqlglot.expressions.Column), (
@@ -117,17 +174,19 @@ def _eval_predicate(
117174

118175
field_info = file_info.get(referenced_field_name)
119176

177+
# Right now if the field is not present in the file,
178+
# just note that we couldn't evaluate the expression.
120179
if field_info is None:
121180
return None
122181

123182
if isinstance(field_info, SetFieldInfo):
124183
match type(node):
125184
case sqlglot.expressions.EQ:
126-
if right_val is None:
185+
if pa.types.is_null(right_val.type):
127186
return False
128187
return right_val in field_info.values
129188
case sqlglot.expressions.NEQ:
130-
if right_val is None:
189+
if pa.types.is_null(right_val.type):
131190
return False
132191
return right_val not in field_info.values
133192
case _:
@@ -136,44 +195,70 @@ def _eval_predicate(
136195
)
137196

138197
if type(node) is sqlglot.expressions.NullSafeNEQ:
139-
if right_val is not None and field_info.has_non_nulls is False:
198+
if (
199+
not pa.types.is_null(right_val.type)
200+
and field_info.has_non_nulls is False
201+
):
140202
return True
141-
return not (field_info.min_value == field_info.max_value == right_val)
203+
204+
if pa.types.is_null(right_val.type):
205+
return field_info.has_non_nulls
206+
207+
return not (
208+
_scalar_value_eq(field_info.min_value, field_info.max_value)
209+
and _scalar_value_eq(field_info.min_value, right_val)
210+
)
211+
142212
elif type(node) is sqlglot.expressions.NullSafeEQ:
143-
if right_val is None and field_info.has_non_nulls:
213+
if pa.types.is_null(right_val.type) and field_info.has_non_nulls:
144214
return True
145215
if field_info.min_value is None or field_info.max_value is None:
146216
return False
147-
assert right_val is not None
148-
return field_info.min_value <= right_val <= field_info.max_value
217+
assert not pa.types.is_null(right_val.type)
218+
return _scalar_value_lte(
219+
field_info.min_value, right_val
220+
) and _scalar_value_lte(right_val, field_info.max_value)
149221

150222
if field_info.min_value is None or field_info.max_value is None:
151223
return False
152224

153-
if right_val is None:
225+
if pa.types.is_null(right_val.type):
154226
return False
155227

156228
match type(node):
157229
case sqlglot.expressions.EQ:
158-
return field_info.min_value <= right_val <= field_info.max_value
230+
return _scalar_value_lte(
231+
field_info.min_value, right_val
232+
) and _scalar_value_lte(right_val, field_info.max_value)
159233
case sqlglot.expressions.NEQ:
160-
return not (field_info.min_value == field_info.max_value == right_val)
234+
return not (
235+
_scalar_value_eq(field_info.min_value, field_info.max_value)
236+
and _scalar_value_eq(field_info.min_value, right_val)
237+
)
161238
case sqlglot.expressions.LT:
162-
return field_info.min_value < right_val
239+
return _scalar_value_lt(field_info.min_value, right_val)
163240
case sqlglot.expressions.LTE:
164-
return field_info.min_value <= right_val
241+
return _scalar_value_lte(field_info.min_value, right_val)
165242
case sqlglot.expressions.GT:
166-
return field_info.max_value > right_val
243+
return _scalar_value_gt(field_info.max_value, right_val)
167244
case sqlglot.expressions.GTE:
168-
return field_info.max_value >= right_val
245+
return _scalar_value_gte(field_info.max_value, right_val)
169246
case sqlglot.expressions.NullSafeEQ:
170-
if right_val is None and field_info.has_non_nulls:
247+
if pa.types.is_null(right_val.type) and field_info.has_non_nulls:
171248
return True
172-
return field_info.min_value <= right_val <= field_info.max_value
249+
return _scalar_value_lte(
250+
field_info.min_value, right_val
251+
) and _scalar_value_lte(right_val, field_info.max_value)
173252
case sqlglot.expressions.NullSafeNEQ:
174-
if right_val is not None and field_info.has_non_nulls is False:
253+
if (
254+
not pa.types.is_null(right_val.type)
255+
and field_info.has_non_nulls is False
256+
):
175257
return True
176-
return not (field_info.min_value == field_info.max_value == right_val)
258+
return not (
259+
_scalar_value_eq(field_info.min_value, field_info.max_value)
260+
and _scalar_value_eq(field_info.min_value, right_val)
261+
)
177262
case _:
178263
raise ValueError(f"Unsupported operator type: {type(node)}")
179264

@@ -234,14 +319,6 @@ def _evaluate_node_in(
234319
return False
235320

236321
for in_exp in node.expressions:
237-
assert isinstance(
238-
in_exp,
239-
sqlglot.expressions.Literal
240-
| sqlglot.expressions.Neg
241-
| sqlglot.expressions.Null,
242-
), (
243-
f"Expected a literal in in side of {node}, got {in_exp} type {type(in_exp)}"
244-
)
245322
if self._eval_predicate(
246323
file_info,
247324
sqlglot.expressions.EQ(this=in_val, expression=in_exp),
@@ -381,9 +458,7 @@ def _evaluate_sql_node(
381458

382459
return False
383460

384-
def get_matching_files(
385-
self, expression: str, *, dialect: str = "duckdb"
386-
) -> set[str]:
461+
def get_matching_files(self, exp: sqlglot.expressions.Expression | str) -> set[str]:
387462
"""
388463
Get a set of files that match the given SQL expression.
389464
Args:
@@ -392,15 +467,23 @@ def get_matching_files(
392467
Returns:
393468
A set of filenames that match the expression.
394469
"""
395-
parse_result = sqlglot.parse_one(expression, dialect=dialect)
470+
if isinstance(exp, str):
471+
# Parse the expression if it is a string.
472+
expression = sqlglot.parse_one(exp, dialect="duckdb")
473+
else:
474+
expression = exp
475+
476+
assert isinstance(expression, sqlglot.expressions.Expression), (
477+
f"Expected a sqlglot expression, got {type(expression)}"
478+
)
396479

397480
# Simplify the parsed expression, move all of the literals to the right side
398-
parse_result = sqlglot.optimizer.simplify.simplify(parse_result)
481+
expression = sqlglot.optimizer.simplify.simplify(expression)
399482

400483
matching_files = set()
401484

402485
for filename, file_info in self.files:
403-
eval_result = self._evaluate_sql_node(parse_result, file_info)
486+
eval_result = self._evaluate_sql_node(expression, file_info)
404487
if eval_result is None or eval_result is True:
405488
# If the expression evaluates to True or cannot be evaluated, add the file
406489
# to the result set since the caller will be able to filter the rows further.

0 commit comments

Comments
 (0)