Skip to content
This repository was archived by the owner on Jan 7, 2025. It is now read-only.

Commit dbc99a0

Browse files
committed
Add JoinSplitFilter rules
Signed-off-by: Yuchen Liang <[email protected]>
1 parent c9f59b0 commit dbc99a0

File tree

1 file changed

+128
-1
lines changed

1 file changed

+128
-1
lines changed

optd-datafusion-repr/src/rules/filter_pushdown.rs

Lines changed: 128 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ use optd_core::optimizer::Optimizer;
2424
use optd_core::rules::{Rule, RuleMatcher};
2525

2626
use super::filter::simplify_log_expr;
27-
use super::macros::define_rule;
27+
use super::macros::{define_rule, define_rule_discriminant};
2828
use crate::plan_nodes::{
2929
ArcDfPlanNode, ArcDfPredNode, ColumnRefPred, DfNodeType, DfPredType, DfReprPlanNode,
3030
DfReprPredNode, JoinType, ListPred, LogOpPred, LogOpType, LogicalAgg, LogicalFilter,
@@ -160,6 +160,87 @@ fn apply_filter_merge(
160160
vec![new_filter.into_plan_node().into()]
161161
}
162162

163+
// Rule to split predicates in a join condition into those that can be pushed down as filters.
164+
define_rule!(
165+
InnerJoinSplitFilterRule,
166+
apply_join_split_filter,
167+
(Join(JoinType::Inner), child_a, child_b)
168+
);
169+
170+
define_rule!(
171+
LeftOuterJoinSplitFilterRule,
172+
apply_join_split_filter,
173+
(Join(JoinType::LeftOuter), child_a, child_b)
174+
);
175+
176+
fn apply_join_split_filter(
177+
optimizer: &impl Optimizer<DfNodeType>,
178+
binding: ArcDfPlanNode,
179+
) -> Vec<PlanNodeOrGroup<DfNodeType>> {
180+
println!("Applying JoinSplitFilterRule");
181+
let join = LogicalJoin::from_plan_node(binding).unwrap();
182+
let left_child = join.left();
183+
let right_child = join.right();
184+
let join_cond = join.cond();
185+
let join_typ = join.join_type();
186+
187+
let left_schema_size = optimizer.get_schema_of(left_child.clone()).len();
188+
let right_schema_size = optimizer.get_schema_of(right_child.clone()).len();
189+
190+
// Conditions that only involve the left relation.
191+
let mut left_conds = vec![];
192+
// Conditions that only involve the right relation.
193+
let mut right_conds = vec![];
194+
// Conditions that involve both relations.
195+
let mut keep_conds = vec![];
196+
197+
let categorization_fn = |expr: ArcDfPredNode, children: &[ArcDfPredNode]| {
198+
let location = determine_join_cond_dep(children, left_schema_size, right_schema_size);
199+
match location {
200+
JoinCondDependency::Left => left_conds.push(expr),
201+
JoinCondDependency::Right => right_conds.push(
202+
expr.rewrite_column_refs(|idx| {
203+
Some(LogicalJoin::map_through_join(
204+
idx,
205+
left_schema_size,
206+
right_schema_size,
207+
))
208+
})
209+
.unwrap(),
210+
),
211+
JoinCondDependency::Both => keep_conds.push(expr),
212+
JoinCondDependency::None => {
213+
unreachable!("join condition should always involve at least one relation");
214+
}
215+
}
216+
};
217+
categorize_conds(categorization_fn, join_cond);
218+
219+
let new_left = if !left_conds.is_empty() {
220+
let new_filter_node =
221+
LogicalFilter::new_unchecked(left_child, and_expr_list_to_expr(left_conds));
222+
PlanNodeOrGroup::PlanNode(new_filter_node.into_plan_node())
223+
} else {
224+
left_child
225+
};
226+
227+
let new_right = if !right_conds.is_empty() {
228+
let new_filter_node =
229+
LogicalFilter::new_unchecked(right_child, and_expr_list_to_expr(right_conds));
230+
PlanNodeOrGroup::PlanNode(new_filter_node.into_plan_node())
231+
} else {
232+
right_child
233+
};
234+
235+
let new_join = LogicalJoin::new_unchecked(
236+
new_left,
237+
new_right,
238+
and_expr_list_to_expr(keep_conds),
239+
*join_typ,
240+
);
241+
242+
vec![new_join.into_plan_node().into()]
243+
}
163244
define_rule!(
164245
FilterInnerJoinTransposeRule,
165246
apply_filter_inner_join_transpose,
@@ -442,6 +523,52 @@ mod tests {
442523
assert_eq!(col_4.value().as_i32(), 1);
443524
}
444525

526+
#[test]
527+
fn join_split_filter() {
528+
let mut test_optimizer = new_test_optimizer(Arc::new(LeftOuterJoinSplitFilterRule::new()));
529+
530+
let scan1 = LogicalScan::new("customer".into());
531+
532+
let scan2 = LogicalScan::new("orders".into());
533+
534+
let join_cond = LogOpPred::new(
535+
LogOpType::And,
536+
vec![
537+
BinOpPred::new(
538+
// This one should be pushed to the left child
539+
ColumnRefPred::new(0).into_pred_node(),
540+
ConstantPred::int32(5).into_pred_node(),
541+
BinOpType::Eq,
542+
)
543+
.into_pred_node(),
544+
BinOpPred::new(
545+
// This one should be pushed to the right child
546+
ColumnRefPred::new(11).into_pred_node(),
547+
ConstantPred::int32(6).into_pred_node(),
548+
BinOpType::Eq,
549+
)
550+
.into_pred_node(),
551+
BinOpPred::new(
552+
// This one stay in join condition
553+
ColumnRefPred::new(2).into_pred_node(),
554+
ColumnRefPred::new(8).into_pred_node(),
555+
BinOpType::Eq,
556+
)
557+
.into_pred_node(),
558+
],
559+
);
560+
561+
let join = LogicalJoin::new(
562+
scan1.into_plan_node(),
563+
scan2.into_plan_node(),
564+
join_cond.into_pred_node(),
565+
super::JoinType::LeftOuter,
566+
);
567+
568+
let plan = test_optimizer.optimize(join.into_plan_node()).unwrap();
569+
println!("{}", plan.explain_to_string(None));
570+
}
571+
445572
#[test]
446573
fn push_past_join_conjunction() {
447574
// Test pushing a complex filter past a join, where one clause can

0 commit comments

Comments
 (0)