@@ -24,7 +24,7 @@ use optd_core::optimizer::Optimizer;
2424use  optd_core:: rules:: { Rule ,  RuleMatcher } ; 
2525
2626use  super :: filter:: simplify_log_expr; 
27- use  super :: macros:: define_rule; 
27+ use  super :: macros:: { define_rule,  define_rule_discriminant } ; 
2828use  crate :: plan_nodes:: { 
2929    ArcDfPlanNode ,  ArcDfPredNode ,  ColumnRefPred ,  DfNodeType ,  DfPredType ,  DfReprPlanNode , 
3030    DfReprPredNode ,  JoinType ,  ListPred ,  LogOpPred ,  LogOpType ,  LogicalAgg ,  LogicalFilter , 
@@ -160,6 +160,87 @@ fn apply_filter_merge(
160160    vec ! [ new_filter. into_plan_node( ) . into( ) ] 
161161} 
162162
163+ // Rule to split predicates in a join condition into those that can be pushed down as filters. 
164+ define_rule ! ( 
165+     InnerJoinSplitFilterRule , 
166+     apply_join_split_filter, 
167+     ( Join ( JoinType :: Inner ) ,  child_a,  child_b) 
168+ ) ; 
169+ 
170+ define_rule ! ( 
171+     LeftOuterJoinSplitFilterRule , 
172+     apply_join_split_filter, 
173+     ( Join ( JoinType :: LeftOuter ) ,  child_a,  child_b) 
174+ ) ; 
175+ 
176+ fn  apply_join_split_filter ( 
177+     optimizer :  & impl  Optimizer < DfNodeType > , 
178+     binding :  ArcDfPlanNode , 
179+ )  -> Vec < PlanNodeOrGroup < DfNodeType > >  { 
180+     println ! ( "Applying JoinSplitFilterRule" ) ; 
181+     let  join = LogicalJoin :: from_plan_node ( binding) . unwrap ( ) ; 
182+     let  left_child = join. left ( ) ; 
183+     let  right_child = join. right ( ) ; 
184+     let  join_cond = join. cond ( ) ; 
185+     let  join_typ = join. join_type ( ) ; 
186+ 
187+     let  left_schema_size = optimizer. get_schema_of ( left_child. clone ( ) ) . len ( ) ; 
188+     let  right_schema_size = optimizer. get_schema_of ( right_child. clone ( ) ) . len ( ) ; 
189+ 
190+     // Conditions that only involve the left relation. 
191+     let  mut  left_conds = vec ! [ ] ; 
192+     // Conditions that only involve the right relation. 
193+     let  mut  right_conds = vec ! [ ] ; 
194+     // Conditions that involve both relations. 
195+     let  mut  keep_conds = vec ! [ ] ; 
196+ 
197+     let  categorization_fn = |expr :  ArcDfPredNode ,  children :  & [ ArcDfPredNode ] | { 
198+         let  location = determine_join_cond_dep ( children,  left_schema_size,  right_schema_size) ; 
199+         match  location { 
200+             JoinCondDependency :: Left  => left_conds. push ( expr) , 
201+             JoinCondDependency :: Right  => right_conds. push ( 
202+                 expr. rewrite_column_refs ( |idx| { 
203+                     Some ( LogicalJoin :: map_through_join ( 
204+                         idx, 
205+                         left_schema_size, 
206+                         right_schema_size, 
207+                     ) ) 
208+                 } ) 
209+                 . unwrap ( ) , 
210+             ) , 
211+             JoinCondDependency :: Both  => keep_conds. push ( expr) , 
212+             JoinCondDependency :: None  => { 
213+                 unreachable ! ( "join condition should always involve at least one relation" ) ; 
214+             } 
215+         } 
216+     } ; 
217+     categorize_conds ( categorization_fn,  join_cond) ; 
218+ 
219+     let  new_left = if  !left_conds. is_empty ( )  { 
220+         let  new_filter_node =
221+             LogicalFilter :: new_unchecked ( left_child,  and_expr_list_to_expr ( left_conds) ) ; 
222+         PlanNodeOrGroup :: PlanNode ( new_filter_node. into_plan_node ( ) ) 
223+     }  else  { 
224+         left_child
225+     } ; 
226+ 
227+     let  new_right = if  !right_conds. is_empty ( )  { 
228+         let  new_filter_node =
229+             LogicalFilter :: new_unchecked ( right_child,  and_expr_list_to_expr ( right_conds) ) ; 
230+         PlanNodeOrGroup :: PlanNode ( new_filter_node. into_plan_node ( ) ) 
231+     }  else  { 
232+         right_child
233+     } ; 
234+ 
235+     let  new_join = LogicalJoin :: new_unchecked ( 
236+         new_left, 
237+         new_right, 
238+         and_expr_list_to_expr ( keep_conds) , 
239+         * join_typ, 
240+     ) ; 
241+ 
242+     vec ! [ new_join. into_plan_node( ) . into( ) ] 
243+ } 
163244define_rule ! ( 
164245    FilterInnerJoinTransposeRule , 
165246    apply_filter_inner_join_transpose, 
@@ -442,6 +523,52 @@ mod tests {
442523        assert_eq ! ( col_4. value( ) . as_i32( ) ,  1 ) ; 
443524    } 
444525
526+     #[ test]  
527+     fn  join_split_filter ( )  { 
528+         let  mut  test_optimizer = new_test_optimizer ( Arc :: new ( LeftOuterJoinSplitFilterRule :: new ( ) ) ) ; 
529+ 
530+         let  scan1 = LogicalScan :: new ( "customer" . into ( ) ) ; 
531+ 
532+         let  scan2 = LogicalScan :: new ( "orders" . into ( ) ) ; 
533+ 
534+         let  join_cond = LogOpPred :: new ( 
535+             LogOpType :: And , 
536+             vec ! [ 
537+                 BinOpPred :: new( 
538+                     // This one should be pushed to the left child 
539+                     ColumnRefPred :: new( 0 ) . into_pred_node( ) , 
540+                     ConstantPred :: int32( 5 ) . into_pred_node( ) , 
541+                     BinOpType :: Eq , 
542+                 ) 
543+                 . into_pred_node( ) , 
544+                 BinOpPred :: new( 
545+                     // This one should be pushed to the right child 
546+                     ColumnRefPred :: new( 11 ) . into_pred_node( ) , 
547+                     ConstantPred :: int32( 6 ) . into_pred_node( ) , 
548+                     BinOpType :: Eq , 
549+                 ) 
550+                 . into_pred_node( ) , 
551+                 BinOpPred :: new( 
552+                     // This one stay in join condition 
553+                     ColumnRefPred :: new( 2 ) . into_pred_node( ) , 
554+                     ColumnRefPred :: new( 8 ) . into_pred_node( ) , 
555+                     BinOpType :: Eq , 
556+                 ) 
557+                 . into_pred_node( ) , 
558+             ] , 
559+         ) ; 
560+ 
561+         let  join = LogicalJoin :: new ( 
562+             scan1. into_plan_node ( ) , 
563+             scan2. into_plan_node ( ) , 
564+             join_cond. into_pred_node ( ) , 
565+             super :: JoinType :: LeftOuter , 
566+         ) ; 
567+ 
568+         let  plan = test_optimizer. optimize ( join. into_plan_node ( ) ) . unwrap ( ) ; 
569+         println ! ( "{}" ,  plan. explain_to_string( None ) ) ; 
570+     } 
571+ 
445572    #[ test]  
446573    fn  push_past_join_conjunction ( )  { 
447574        // Test pushing a complex filter past a join, where one clause can 
0 commit comments