Skip to content
This repository has been archived by the owner on Jan 7, 2025. It is now read-only.

Commit

Permalink
Add JoinSplitFilter rules
Browse files Browse the repository at this point in the history
Signed-off-by: Yuchen Liang <[email protected]>
  • Loading branch information
yliang412 committed Jan 6, 2025
1 parent c9f59b0 commit dbc99a0
Showing 1 changed file with 128 additions and 1 deletion.
129 changes: 128 additions & 1 deletion optd-datafusion-repr/src/rules/filter_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ use optd_core::optimizer::Optimizer;
use optd_core::rules::{Rule, RuleMatcher};

use super::filter::simplify_log_expr;
use super::macros::define_rule;
use super::macros::{define_rule, define_rule_discriminant};
use crate::plan_nodes::{
ArcDfPlanNode, ArcDfPredNode, ColumnRefPred, DfNodeType, DfPredType, DfReprPlanNode,
DfReprPredNode, JoinType, ListPred, LogOpPred, LogOpType, LogicalAgg, LogicalFilter,
Expand Down Expand Up @@ -160,6 +160,87 @@ fn apply_filter_merge(
vec![new_filter.into_plan_node().into()]
}

// Rule to split predicates in a join condition into those that can be pushed down as filters.
define_rule!(
InnerJoinSplitFilterRule,
apply_join_split_filter,
(Join(JoinType::Inner), child_a, child_b)
);

define_rule!(
LeftOuterJoinSplitFilterRule,
apply_join_split_filter,
(Join(JoinType::LeftOuter), child_a, child_b)
);

fn apply_join_split_filter(
optimizer: &impl Optimizer<DfNodeType>,
binding: ArcDfPlanNode,
) -> Vec<PlanNodeOrGroup<DfNodeType>> {
println!("Applying JoinSplitFilterRule");
let join = LogicalJoin::from_plan_node(binding).unwrap();
let left_child = join.left();
let right_child = join.right();
let join_cond = join.cond();
let join_typ = join.join_type();

let left_schema_size = optimizer.get_schema_of(left_child.clone()).len();
let right_schema_size = optimizer.get_schema_of(right_child.clone()).len();

// Conditions that only involve the left relation.
let mut left_conds = vec![];
// Conditions that only involve the right relation.
let mut right_conds = vec![];
// Conditions that involve both relations.
let mut keep_conds = vec![];

let categorization_fn = |expr: ArcDfPredNode, children: &[ArcDfPredNode]| {
let location = determine_join_cond_dep(children, left_schema_size, right_schema_size);
match location {
JoinCondDependency::Left => left_conds.push(expr),
JoinCondDependency::Right => right_conds.push(
expr.rewrite_column_refs(|idx| {
Some(LogicalJoin::map_through_join(
idx,
left_schema_size,
right_schema_size,
))
})
.unwrap(),
),
JoinCondDependency::Both => keep_conds.push(expr),
JoinCondDependency::None => {
unreachable!("join condition should always involve at least one relation");
}
}
};
categorize_conds(categorization_fn, join_cond);

let new_left = if !left_conds.is_empty() {
let new_filter_node =
LogicalFilter::new_unchecked(left_child, and_expr_list_to_expr(left_conds));
PlanNodeOrGroup::PlanNode(new_filter_node.into_plan_node())
} else {
left_child
};

let new_right = if !right_conds.is_empty() {
let new_filter_node =
LogicalFilter::new_unchecked(right_child, and_expr_list_to_expr(right_conds));
PlanNodeOrGroup::PlanNode(new_filter_node.into_plan_node())
} else {
right_child
};

let new_join = LogicalJoin::new_unchecked(
new_left,
new_right,
and_expr_list_to_expr(keep_conds),
*join_typ,
);

vec![new_join.into_plan_node().into()]
}
define_rule!(
FilterInnerJoinTransposeRule,
apply_filter_inner_join_transpose,
Expand Down Expand Up @@ -442,6 +523,52 @@ mod tests {
assert_eq!(col_4.value().as_i32(), 1);
}

#[test]
fn join_split_filter() {
let mut test_optimizer = new_test_optimizer(Arc::new(LeftOuterJoinSplitFilterRule::new()));

let scan1 = LogicalScan::new("customer".into());

let scan2 = LogicalScan::new("orders".into());

let join_cond = LogOpPred::new(
LogOpType::And,
vec![
BinOpPred::new(
// This one should be pushed to the left child
ColumnRefPred::new(0).into_pred_node(),
ConstantPred::int32(5).into_pred_node(),
BinOpType::Eq,
)
.into_pred_node(),
BinOpPred::new(
// This one should be pushed to the right child
ColumnRefPred::new(11).into_pred_node(),
ConstantPred::int32(6).into_pred_node(),
BinOpType::Eq,
)
.into_pred_node(),
BinOpPred::new(
// This one stay in join condition
ColumnRefPred::new(2).into_pred_node(),
ColumnRefPred::new(8).into_pred_node(),
BinOpType::Eq,
)
.into_pred_node(),
],
);

let join = LogicalJoin::new(
scan1.into_plan_node(),
scan2.into_plan_node(),
join_cond.into_pred_node(),
super::JoinType::LeftOuter,
);

let plan = test_optimizer.optimize(join.into_plan_node()).unwrap();
println!("{}", plan.explain_to_string(None));
}

#[test]
fn push_past_join_conjunction() {
// Test pushing a complex filter past a join, where one clause can
Expand Down

0 comments on commit dbc99a0

Please sign in to comment.