@@ -24,7 +24,7 @@ use optd_core::optimizer::Optimizer;
24
24
use optd_core:: rules:: { Rule , RuleMatcher } ;
25
25
26
26
use super :: filter:: simplify_log_expr;
27
- use super :: macros:: define_rule;
27
+ use super :: macros:: { define_rule, define_rule_discriminant } ;
28
28
use crate :: plan_nodes:: {
29
29
ArcDfPlanNode , ArcDfPredNode , ColumnRefPred , DfNodeType , DfPredType , DfReprPlanNode ,
30
30
DfReprPredNode , JoinType , ListPred , LogOpPred , LogOpType , LogicalAgg , LogicalFilter ,
@@ -160,6 +160,87 @@ fn apply_filter_merge(
160
160
vec ! [ new_filter. into_plan_node( ) . into( ) ]
161
161
}
162
162
163
+ // Rule to split predicates in a join condition into those that can be pushed down as filters.
164
+ define_rule ! (
165
+ InnerJoinSplitFilterRule ,
166
+ apply_join_split_filter,
167
+ ( Join ( JoinType :: Inner ) , child_a, child_b)
168
+ ) ;
169
+
170
+ define_rule ! (
171
+ LeftOuterJoinSplitFilterRule ,
172
+ apply_join_split_filter,
173
+ ( Join ( JoinType :: LeftOuter ) , child_a, child_b)
174
+ ) ;
175
+
176
+ fn apply_join_split_filter (
177
+ optimizer : & impl Optimizer < DfNodeType > ,
178
+ binding : ArcDfPlanNode ,
179
+ ) -> Vec < PlanNodeOrGroup < DfNodeType > > {
180
+ println ! ( "Applying JoinSplitFilterRule" ) ;
181
+ let join = LogicalJoin :: from_plan_node ( binding) . unwrap ( ) ;
182
+ let left_child = join. left ( ) ;
183
+ let right_child = join. right ( ) ;
184
+ let join_cond = join. cond ( ) ;
185
+ let join_typ = join. join_type ( ) ;
186
+
187
+ let left_schema_size = optimizer. get_schema_of ( left_child. clone ( ) ) . len ( ) ;
188
+ let right_schema_size = optimizer. get_schema_of ( right_child. clone ( ) ) . len ( ) ;
189
+
190
+ // Conditions that only involve the left relation.
191
+ let mut left_conds = vec ! [ ] ;
192
+ // Conditions that only involve the right relation.
193
+ let mut right_conds = vec ! [ ] ;
194
+ // Conditions that involve both relations.
195
+ let mut keep_conds = vec ! [ ] ;
196
+
197
+ let categorization_fn = |expr : ArcDfPredNode , children : & [ ArcDfPredNode ] | {
198
+ let location = determine_join_cond_dep ( children, left_schema_size, right_schema_size) ;
199
+ match location {
200
+ JoinCondDependency :: Left => left_conds. push ( expr) ,
201
+ JoinCondDependency :: Right => right_conds. push (
202
+ expr. rewrite_column_refs ( |idx| {
203
+ Some ( LogicalJoin :: map_through_join (
204
+ idx,
205
+ left_schema_size,
206
+ right_schema_size,
207
+ ) )
208
+ } )
209
+ . unwrap ( ) ,
210
+ ) ,
211
+ JoinCondDependency :: Both => keep_conds. push ( expr) ,
212
+ JoinCondDependency :: None => {
213
+ unreachable ! ( "join condition should always involve at least one relation" ) ;
214
+ }
215
+ }
216
+ } ;
217
+ categorize_conds ( categorization_fn, join_cond) ;
218
+
219
+ let new_left = if !left_conds. is_empty ( ) {
220
+ let new_filter_node =
221
+ LogicalFilter :: new_unchecked ( left_child, and_expr_list_to_expr ( left_conds) ) ;
222
+ PlanNodeOrGroup :: PlanNode ( new_filter_node. into_plan_node ( ) )
223
+ } else {
224
+ left_child
225
+ } ;
226
+
227
+ let new_right = if !right_conds. is_empty ( ) {
228
+ let new_filter_node =
229
+ LogicalFilter :: new_unchecked ( right_child, and_expr_list_to_expr ( right_conds) ) ;
230
+ PlanNodeOrGroup :: PlanNode ( new_filter_node. into_plan_node ( ) )
231
+ } else {
232
+ right_child
233
+ } ;
234
+
235
+ let new_join = LogicalJoin :: new_unchecked (
236
+ new_left,
237
+ new_right,
238
+ and_expr_list_to_expr ( keep_conds) ,
239
+ * join_typ,
240
+ ) ;
241
+
242
+ vec ! [ new_join. into_plan_node( ) . into( ) ]
243
+ }
163
244
define_rule ! (
164
245
FilterInnerJoinTransposeRule ,
165
246
apply_filter_inner_join_transpose,
@@ -442,6 +523,52 @@ mod tests {
442
523
assert_eq ! ( col_4. value( ) . as_i32( ) , 1 ) ;
443
524
}
444
525
526
#[test]
fn join_split_filter() {
    // Runs the left-outer variant of the join-split-filter rule over a join whose
    // condition mixes left-only, right-only and two-sided predicates.
    let mut test_optimizer = new_test_optimizer(Arc::new(LeftOuterJoinSplitFilterRule::new()));

    let scan1 = LogicalScan::new("customer".into());

    let scan2 = LogicalScan::new("orders".into());

    let join_cond = LogOpPred::new(
        LogOpType::And,
        vec![
            BinOpPred::new(
                // Meant to reference only the left (customer) input.
                ColumnRefPred::new(0).into_pred_node(),
                ConstantPred::int32(5).into_pred_node(),
                BinOpType::Eq,
            )
            .into_pred_node(),
            BinOpPred::new(
                // Meant to reference only the right (orders) input, so it
                // should be pushed to the right child.
                ColumnRefPred::new(11).into_pred_node(),
                ConstantPred::int32(6).into_pred_node(),
                BinOpType::Eq,
            )
            .into_pred_node(),
            BinOpPred::new(
                // References both inputs, so it should stay in the join condition.
                ColumnRefPred::new(2).into_pred_node(),
                ColumnRefPred::new(8).into_pred_node(),
                BinOpType::Eq,
            )
            .into_pred_node(),
        ],
    );

    let join = LogicalJoin::new(
        scan1.into_plan_node(),
        scan2.into_plan_node(),
        join_cond.into_pred_node(),
        super::JoinType::LeftOuter,
    );

    let plan = test_optimizer.optimize(join.into_plan_node()).unwrap();
    println!("{}", plan.explain_to_string(None));

    // The root must still be a left outer join after the rewrite.
    let new_join = super::LogicalJoin::from_plan_node(plan).unwrap();
    assert!(matches!(
        *new_join.join_type(),
        super::JoinType::LeftOuter
    ));

    // The right-only predicate must have been turned into a filter on the right input.
    match new_join.right() {
        super::PlanNodeOrGroup::PlanNode(node) => {
            let _right_filter = super::LogicalFilter::from_plan_node(node)
                .expect("right input of the join should be a filter");
        }
        _ => panic!("right input of the join should be a concrete plan node"),
    }
}
571
+
445
572
#[ test]
446
573
fn push_past_join_conjunction ( ) {
447
574
// Test pushing a complex filter past a join, where one clause can
0 commit comments