 from collections.abc import Callable
 from dataclasses import dataclass
-from decimal import Decimal
 from typing import Any
-
+import duckdb
+import pyarrow as pa
 import sqlglot
 import sqlglot.expressions
 import sqlglot.optimizer.simplify
@@ -19,36 +19,82 @@ class BaseFieldInfo:


 @dataclass
-class RangeFieldInfo[T: Any](BaseFieldInfo):
+class RangeFieldInfo(BaseFieldInfo):
     """
     Information about a field that has a min and max value.
     """

-    min_value: T
-    max_value: T
+    min_value: pa.Scalar
+    max_value: pa.Scalar


 @dataclass
-class SetFieldInfo[T: Any](BaseFieldInfo):
+class SetFieldInfo(BaseFieldInfo):
     """
     Information about a field where the set of values are known.
     The information about what values that are contained can produce
     false positives.
     """

     values: set[
-        T
+        pa.Scalar
     ]  # Set of values that are known to be present in the field, false positives are okay.


-AnyFieldInfo = (
-    SetFieldInfo[Decimal]
-    | SetFieldInfo[float]
-    | SetFieldInfo[str]
-    | SetFieldInfo[int]
-    | RangeFieldInfo[int]
-    | RangeFieldInfo[None]
-)
+AnyFieldInfo = SetFieldInfo | RangeFieldInfo
+
+
+def _scalar_value_op(
+    a: pa.Scalar, b: pa.Scalar, op: Callable[[Any, Any], bool]
+) -> bool:
+    assert not pa.types.is_null(a.type), (
+        f"Expected a non-null scalar value, got {a} of type {a.type}"
+    )
+    assert not pa.types.is_null(b.type), (
+        f"Expected a non-null scalar value, got {b} of type {b.type}"
+    )
+
+    # If we have integers or floats we can do that comparison regardless of their types.
+    if pa.types.is_integer(a.type) and pa.types.is_integer(b.type):
+        return op(a.as_py(), b.as_py())
+
+    if pa.types.is_floating(a.type) and pa.types.is_floating(b.type):
+        return op(a.as_py(), b.as_py())
+
+    if pa.types.is_string(a.type) and pa.types.is_string(b.type):
+        return op(a.as_py(), b.as_py())
+
+    if pa.types.is_boolean(a.type) and pa.types.is_boolean(b.type):
+        return op(a.as_py(), b.as_py())
+
+    if pa.types.is_decimal(a.type) and pa.types.is_decimal(b.type):
+        return op(a.as_py(), b.as_py())
+
+    assert type(a) is type(b), (
+        f"Expected same type for comparison, got {type(a)} and {type(b)}"
+    )
+
+    return op(a.as_py(), b.as_py())
+
+
+def _scalar_value_lte(a: pa.Scalar, b: pa.Scalar) -> bool:
+    return _scalar_value_op(a, b, lambda x, y: x <= y)
+
+
+def _scalar_value_lt(a: pa.Scalar, b: pa.Scalar) -> bool:
+    return _scalar_value_op(a, b, lambda x, y: x < y)
+
+
+def _scalar_value_gt(a: pa.Scalar, b: pa.Scalar) -> bool:
+    return _scalar_value_op(a, b, lambda x, y: x > y)
+
+
+def _scalar_value_gte(a: pa.Scalar, b: pa.Scalar) -> bool:
+    return _scalar_value_op(a, b, lambda x, y: x >= y)
+
+
+def _scalar_value_eq(a: pa.Scalar, b: pa.Scalar) -> bool:
+    return _scalar_value_op(a, b, lambda x, y: x == y)


 FileFieldInfo = dict[str, AnyFieldInfo]
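For illustration, a minimal sketch of how the `_scalar_value_*` helpers above behave; it assumes those helpers are in scope, and the concrete values are made up. Since the comparison happens on the `.as_py()` values, scalars with different Arrow integer widths compare as expected:

```python
import pyarrow as pa

# Assumes the _scalar_value_* helpers defined above are in scope.
a = pa.scalar(3, type=pa.int32())
b = pa.scalar(5, type=pa.int64())

# Both sides are integer types, so the comparison happens on the Python
# values even though the Arrow widths differ (int32 vs int64).
assert _scalar_value_lt(a, b)        # 3 < 5
assert _scalar_value_lte(a, a)       # 3 <= 3
assert not _scalar_value_eq(a, b)    # 3 != 5

# Null scalars are rejected up front by the assertions in _scalar_value_op:
# _scalar_value_eq(pa.scalar(None, type=pa.null()), b)  # would raise AssertionError
```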
@@ -93,18 +139,29 @@ def _eval_predicate(
         if not isinstance(node.left, sqlglot.expressions.Column):
             return None

+        if node.right.find(sqlglot.expressions.Column) is not None:
+            # Can't evaluate this since it has a right hand column ref, ideally
+            # this should be removed further up.
+            return None
+
         # The thing on the right side should be something that can be evaluated against a range.
         # ideally, it's going to be a
-        assert isinstance(
-            node.right,
-            sqlglot.expressions.Literal
-            | sqlglot.expressions.Null
-            | sqlglot.expressions.Neg,
-        ), (
-            f"Expected a literal or null on righthand side of predicate {node} got a {type(node.right)}"
-        )
+        if True:  # isinstance(node.right, sqlglot.expressions.Cast):
+            connection = duckdb.connect(":memory:")
+            value_result = connection.execute(
+                f"select {node.right.sql('duckdb')}"
+            ).arrow()
+            assert value_result.num_rows == 1, (
+                f"Expected a single row result from cast, got {value_result.num_rows} rows"
+            )
+            assert value_result.num_columns == 1, (
+                f"Expected a single column result from cast, got {value_result.num_columns} columns"
+            )

-        right_val = node.right.to_py()
+        right_val = value_result.column(0)[0]
+        # This is an interesting behavior, null is returned with an int32 type.
+        if type(right_val) is pa.Int32Scalar and right_val.as_py() is None:
+            right_val = pa.scalar(None, type=pa.null())

         left_val = node.left
         assert isinstance(left_val, sqlglot.expressions.Column), (
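The hunk above drops the literal-only assertion in favor of a more general approach: the right-hand side of the predicate is rendered back to DuckDB SQL, evaluated by an in-memory DuckDB connection, and the single Arrow value in the result becomes `right_val`. A self-contained sketch of that trick; the column name `ts` and the date literal are illustrative only:

```python
import duckdb
import sqlglot

# Parse a predicate and pull out its right-hand side (a CAST expression here).
node = sqlglot.parse_one("ts >= CAST('2024-01-01' AS DATE)", dialect="duckdb")
right = node.expression  # sqlglot binary nodes: .this is the left side, .expression the right

# Let DuckDB evaluate the expression and hand the result back as Arrow data.
table = duckdb.connect(":memory:").execute(f"select {right.sql('duckdb')}").arrow()
assert table.num_rows == 1 and table.num_columns == 1

right_val = table.column(0)[0]  # a pyarrow Scalar (a date32 scalar in this case)
print(right_val, right_val.type)
```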
@@ -117,17 +174,19 @@ def _eval_predicate(

         field_info = file_info.get(referenced_field_name)

+        # Right now if the field is not present in the file,
+        # just note that we couldn't evaluate the expression.
         if field_info is None:
             return None

         if isinstance(field_info, SetFieldInfo):
             match type(node):
                 case sqlglot.expressions.EQ:
-                    if right_val is None:
+                    if pa.types.is_null(right_val.type):
                         return False
                     return right_val in field_info.values
                 case sqlglot.expressions.NEQ:
-                    if right_val is None:
+                    if pa.types.is_null(right_val.type):
                         return False
                     return right_val not in field_info.values
                 case _:
@@ -136,44 +195,70 @@ def _eval_predicate(
                     )

         if type(node) is sqlglot.expressions.NullSafeNEQ:
-            if right_val is not None and field_info.has_non_nulls is False:
+            if (
+                not pa.types.is_null(right_val.type)
+                and field_info.has_non_nulls is False
+            ):
                 return True
-            return not (field_info.min_value == field_info.max_value == right_val)
+
+            if pa.types.is_null(right_val.type):
+                return field_info.has_non_nulls
+
+            return not (
+                _scalar_value_eq(field_info.min_value, field_info.max_value)
+                and _scalar_value_eq(field_info.min_value, right_val)
+            )
+
         elif type(node) is sqlglot.expressions.NullSafeEQ:
-            if right_val is None and field_info.has_non_nulls:
+            if pa.types.is_null(right_val.type) and field_info.has_non_nulls:
                 return True
             if field_info.min_value is None or field_info.max_value is None:
                 return False
-            assert right_val is not None
-            return field_info.min_value <= right_val <= field_info.max_value
+            assert not pa.types.is_null(right_val.type)
+            return _scalar_value_lte(
+                field_info.min_value, right_val
+            ) and _scalar_value_lte(right_val, field_info.max_value)

         if field_info.min_value is None or field_info.max_value is None:
             return False

-        if right_val is None:
+        if pa.types.is_null(right_val.type):
             return False

         match type(node):
             case sqlglot.expressions.EQ:
-                return field_info.min_value <= right_val <= field_info.max_value
+                return _scalar_value_lte(
+                    field_info.min_value, right_val
+                ) and _scalar_value_lte(right_val, field_info.max_value)
             case sqlglot.expressions.NEQ:
-                return not (field_info.min_value == field_info.max_value == right_val)
+                return not (
+                    _scalar_value_eq(field_info.min_value, field_info.max_value)
+                    and _scalar_value_eq(field_info.min_value, right_val)
+                )
             case sqlglot.expressions.LT:
-                return field_info.min_value < right_val
+                return _scalar_value_lt(field_info.min_value, right_val)
             case sqlglot.expressions.LTE:
-                return field_info.min_value <= right_val
+                return _scalar_value_lte(field_info.min_value, right_val)
             case sqlglot.expressions.GT:
-                return field_info.max_value > right_val
+                return _scalar_value_gt(field_info.max_value, right_val)
             case sqlglot.expressions.GTE:
-                return field_info.max_value >= right_val
+                return _scalar_value_gte(field_info.max_value, right_val)
             case sqlglot.expressions.NullSafeEQ:
-                if right_val is None and field_info.has_non_nulls:
+                if pa.types.is_null(right_val.type) and field_info.has_non_nulls:
                     return True
-                return field_info.min_value <= right_val <= field_info.max_value
+                return _scalar_value_lte(
+                    field_info.min_value, right_val
+                ) and _scalar_value_lte(right_val, field_info.max_value)
             case sqlglot.expressions.NullSafeNEQ:
-                if right_val is not None and field_info.has_non_nulls is False:
+                if (
+                    not pa.types.is_null(right_val.type)
+                    and field_info.has_non_nulls is False
+                ):
                     return True
-                return not (field_info.min_value == field_info.max_value == right_val)
+                return not (
+                    _scalar_value_eq(field_info.min_value, field_info.max_value)
+                    and _scalar_value_eq(field_info.min_value, right_val)
+                )
             case _:
                 raise ValueError(f"Unsupported operator type: {type(node)}")

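For context on the NullSafeEQ/NullSafeNEQ branches above: sqlglot uses those node types for SQL's null-safe comparisons, which, unlike `=` and `<>`, yield TRUE or FALSE even when one side is NULL, which is why their null handling differs from the plain EQ/NEQ cases. A small check of how such predicates parse (the exact parse output may vary between sqlglot versions):

```python
import sqlglot

# Null-safe comparisons never return NULL, so they get their own handling.
eq_node = sqlglot.parse_one("x IS NOT DISTINCT FROM 5", dialect="duckdb")
neq_node = sqlglot.parse_one("x IS DISTINCT FROM 5", dialect="duckdb")
print(type(eq_node))   # expected: sqlglot.expressions.NullSafeEQ
print(type(neq_node))  # expected: sqlglot.expressions.NullSafeNEQ
```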
@@ -234,14 +319,6 @@ def _evaluate_node_in(
             return False

         for in_exp in node.expressions:
-            assert isinstance(
-                in_exp,
-                sqlglot.expressions.Literal
-                | sqlglot.expressions.Neg
-                | sqlglot.expressions.Null,
-            ), (
-                f"Expected a literal in in side of {node}, got {in_exp} type {type(in_exp)}"
-            )
             if self._eval_predicate(
                 file_info,
                 sqlglot.expressions.EQ(this=in_val, expression=in_exp),
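With the removed assertion, IN lists are no longer restricted to plain literals; each element is simply wrapped in a synthetic EQ predicate and evaluated through `_eval_predicate`. A sketch of roughly what one of those synthetic nodes looks like; the column `category` and the values are made up:

```python
import sqlglot
import sqlglot.expressions

in_node = sqlglot.parse_one("category IN (1, 2, 3)", dialect="duckdb")
in_val = in_node.this              # the column reference `category`
first = in_node.expressions[0]     # the literal 1

# The loop above builds one EQ node like this per element of the IN list.
eq = sqlglot.expressions.EQ(this=in_val, expression=first)
print(eq.sql("duckdb"))            # category = 1
```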
@@ -381,9 +458,7 @@ def _evaluate_sql_node(

         return False

-    def get_matching_files(
-        self, expression: str, *, dialect: str = "duckdb"
-    ) -> set[str]:
+    def get_matching_files(self, exp: sqlglot.expressions.Expression | str) -> set[str]:
         """
         Get a set of files that match the given SQL expression.
         Args:
@@ -392,15 +467,23 @@ def get_matching_files(
         Returns:
             A set of filenames that match the expression.
         """
-        parse_result = sqlglot.parse_one(expression, dialect=dialect)
+        if isinstance(exp, str):
+            # Parse the expression if it is a string.
+            expression = sqlglot.parse_one(exp, dialect="duckdb")
+        else:
+            expression = exp
+
+        assert isinstance(expression, sqlglot.expressions.Expression), (
+            f"Expected a sqlglot expression, got {type(expression)}"
+        )

         # Simplify the parsed expression, move all of the literals to the right side
-        parse_result = sqlglot.optimizer.simplify.simplify(parse_result)
+        expression = sqlglot.optimizer.simplify.simplify(expression)

         matching_files = set()

         for filename, file_info in self.files:
-            eval_result = self._evaluate_sql_node(parse_result, file_info)
+            eval_result = self._evaluate_sql_node(expression, file_info)
             if eval_result is None or eval_result is True:
                 # If the expression evaluates to True or cannot be evaluated, add the file
                 # to the result set since the caller will be able to filter the rows further.
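`get_matching_files` now accepts either a raw SQL string (parsed with the DuckDB dialect) or a pre-parsed sqlglot expression; either way the predicate is run through `sqlglot.optimizer.simplify.simplify` before per-file evaluation, which constant-folds literal arithmetic and, per the comment in the code, leaves literals on the right-hand side of each comparison. A small sketch of that normalization step on its own:

```python
import sqlglot
import sqlglot.optimizer.simplify

# Constant folding: the arithmetic on the right collapses to a single literal,
# which is what _eval_predicate expects to evaluate against the file stats.
expr = sqlglot.parse_one("y = 1 + 2", dialect="duckdb")
print(sqlglot.optimizer.simplify.simplify(expr).sql("duckdb"))  # y = 3
```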