From dbc99a058b25b85e865b2fc2946d8ebd0072d74e Mon Sep 17 00:00:00 2001
From: Yuchen Liang
Date: Mon, 6 Jan 2025 12:37:03 -0500
Subject: [PATCH] Add JoinSplitFilter rules

Split the conjuncts of a join condition and push those that reference a
single input below the join as filters. Inner joins push to both inputs;
left outer joins push only to the right input, because a left-only conjunct
in the ON clause must not remove rows from the preserved side.

Signed-off-by: Yuchen Liang
---
 .../src/rules/filter_pushdown.rs | 171 ++++++++++++++++++++++++++++++
 1 file changed, 171 insertions(+)

diff --git a/optd-datafusion-repr/src/rules/filter_pushdown.rs b/optd-datafusion-repr/src/rules/filter_pushdown.rs
index dd46734b..1abd7a6d 100644
--- a/optd-datafusion-repr/src/rules/filter_pushdown.rs
+++ b/optd-datafusion-repr/src/rules/filter_pushdown.rs
@@ -160,6 +160,117 @@ fn apply_filter_merge(
     vec![new_filter.into_plan_node().into()]
 }
 
+// Rules that split the conjuncts of a join condition: a conjunct that references only one
+// input is pushed below the join as a filter on that input, the rest stay in the join.
+define_rule!(
+    InnerJoinSplitFilterRule,
+    apply_join_split_filter,
+    (Join(JoinType::Inner), child_a, child_b)
+);
+
+define_rule!(
+    LeftOuterJoinSplitFilterRule,
+    apply_join_split_filter,
+    (Join(JoinType::LeftOuter), child_a, child_b)
+);
+
+/// Splits a join condition into filters below the join and a residual join condition.
+///
+/// Which input may receive a filter depends on the join type:
+/// * inner join: both inputs;
+/// * left outer join: only the right (null-supplying) input. A left-only conjunct stays in
+///   the join condition, because filtering the preserved left input would drop rows that the
+///   outer join still has to emit padded with NULLs.
+///
+/// Any other join type is left untouched. Returns an empty vector when no conjunct can be
+/// pushed down, so the rule never produces a plan identical to its input.
+fn apply_join_split_filter(
+    optimizer: &impl Optimizer<DfNodeType>,
+    binding: ArcDfPlanNode,
+) -> Vec<PlanNodeOrGroup<DfNodeType>> {
+    let join = LogicalJoin::from_plan_node(binding).unwrap();
+    let left_child = join.left();
+    let right_child = join.right();
+    let join_cond = join.cond();
+    let join_typ = join.join_type();
+
+    // Sides that may legally receive a pushed-down conjunct for this join type.
+    let (can_push_left, can_push_right) = match *join_typ {
+        JoinType::Inner => (true, true),
+        JoinType::LeftOuter => (false, true),
+        _ => (false, false),
+    };
+
+    let left_schema_size = optimizer.get_schema_of(left_child.clone()).len();
+    let right_schema_size = optimizer.get_schema_of(right_child.clone()).len();
+
+    // Conjuncts that only involve the left relation and may be pushed below the join.
+    let mut left_conds = vec![];
+    // Conjuncts that only involve the right relation and may be pushed below the join.
+    let mut right_conds = vec![];
+    // Conjuncts that stay in the join condition.
+    let mut keep_conds = vec![];
+
+    let categorization_fn = |expr: ArcDfPredNode, children: &[ArcDfPredNode]| {
+        let location = determine_join_cond_dep(children, left_schema_size, right_schema_size);
+        match location {
+            JoinCondDependency::Left if can_push_left => left_conds.push(expr),
+            JoinCondDependency::Right if can_push_right => right_conds.push(
+                // The conjunct moves below the join, so its column references are remapped
+                // with `map_through_join` before it is attached to the right child.
+                expr.rewrite_column_refs(|idx| {
+                    Some(LogicalJoin::map_through_join(
+                        idx,
+                        left_schema_size,
+                        right_schema_size,
+                    ))
+                })
+                .unwrap(),
+            ),
+            // Leaving a conjunct in the join condition is always correct. This covers
+            // conjuncts over both inputs, single-sided ones that must not be pushed for this
+            // join type, and conjuncts that depend on neither input.
+            JoinCondDependency::Left
+            | JoinCondDependency::Right
+            | JoinCondDependency::Both
+            | JoinCondDependency::None => keep_conds.push(expr),
+        }
+    };
+    categorize_conds(categorization_fn, join_cond);
+
+    // Nothing moves below the join: report no alternative instead of a copy of the input.
+    if left_conds.is_empty() && right_conds.is_empty() {
+        return vec![];
+    }
+
+    let new_left = if left_conds.is_empty() {
+        left_child
+    } else {
+        let new_filter_node =
+            LogicalFilter::new_unchecked(left_child, and_expr_list_to_expr(left_conds));
+        PlanNodeOrGroup::PlanNode(new_filter_node.into_plan_node())
+    };
+
+    let new_right = if right_conds.is_empty() {
+        right_child
+    } else {
+        let new_filter_node =
+            LogicalFilter::new_unchecked(right_child, and_expr_list_to_expr(right_conds));
+        PlanNodeOrGroup::PlanNode(new_filter_node.into_plan_node())
+    };
+
+    // NOTE(review): `keep_conds` is empty when every conjunct was pushed down; confirm that
+    // `and_expr_list_to_expr` turns an empty list into an always-true predicate.
+    let new_join = LogicalJoin::new_unchecked(
+        new_left,
+        new_right,
+        and_expr_list_to_expr(keep_conds),
+        *join_typ,
+    );
+
+    vec![new_join.into_plan_node().into()]
+}
+
 define_rule!(
     FilterInnerJoinTransposeRule,
     apply_filter_inner_join_transpose,
@@ -442,6 +553,66 @@ mod tests {
         assert_eq!(col_4.value().as_i32(), 1);
     }
 
+    #[test]
+    fn join_split_filter() {
+        let mut test_optimizer = new_test_optimizer(Arc::new(LeftOuterJoinSplitFilterRule::new()));
+
+        let scan1 = LogicalScan::new("customer".into());
+
+        let scan2 = LogicalScan::new("orders".into());
+
+        let join_cond = LogOpPred::new(
+            LogOpType::And,
+            vec![
+                BinOpPred::new(
+                    // Left-only: has to stay in the condition of a left outer join
+                    ColumnRefPred::new(0).into_pred_node(),
+                    ConstantPred::int32(5).into_pred_node(),
+                    BinOpType::Eq,
+                )
+                .into_pred_node(),
+                BinOpPred::new(
+                    // Right-only: should be pushed to the right child
+                    ColumnRefPred::new(11).into_pred_node(),
+                    ConstantPred::int32(6).into_pred_node(),
+                    BinOpType::Eq,
+                )
+                .into_pred_node(),
+                BinOpPred::new(
+                    // References both inputs: stays in the join condition
+                    ColumnRefPred::new(2).into_pred_node(),
+                    ColumnRefPred::new(8).into_pred_node(),
+                    BinOpType::Eq,
+                )
+                .into_pred_node(),
+            ],
+        );
+
+        let join = LogicalJoin::new(
+            scan1.into_plan_node(),
+            scan2.into_plan_node(),
+            join_cond.into_pred_node(),
+            super::JoinType::LeftOuter,
+        );
+
+        let plan = test_optimizer.optimize(join.into_plan_node()).unwrap();
+        println!("{}", plan.explain_to_string(None));
+
+        let join = LogicalJoin::from_plan_node(plan).expect("root should still be a join");
+
+        // The preserved left input must not be filtered by a left outer join.
+        let super::PlanNodeOrGroup::PlanNode(left) = join.left() else {
+            panic!("left child should be a plan node");
+        };
+        assert!(super::LogicalFilter::from_plan_node(left).is_none());
+
+        // The right-only conjunct must have become a filter on the right input.
+        let super::PlanNodeOrGroup::PlanNode(right) = join.right() else {
+            panic!("right child should be a plan node");
+        };
+        assert!(super::LogicalFilter::from_plan_node(right).is_some());
+    }
+
     #[test]
     fn push_past_join_conjunction() {
         // Test pushing a complex filter past a join, where one clause can