Skip to content
This repository was archived by the owner on Jan 7, 2025. It is now read-only.

Commit cfa595b

Browse files
authored
feat: Eliminate Filter Rule (#54)
1 parent 81dfe73 commit cfa595b

File tree

10 files changed

+146
-10
lines changed

10 files changed

+146
-10
lines changed

optd-core/src/cascades/memo.rs

+33-5
Original file line numberDiff line numberDiff line change
@@ -109,15 +109,37 @@ impl<T: RelNodeTyp> Memo<T> {
109109
ExprId(id)
110110
}
111111

112-
fn merge_group(&mut self, group_a: ReducedGroupId, group_b: ReducedGroupId) -> ReducedGroupId {
112+
fn merge_group_inner(
113+
&mut self,
114+
group_a: ReducedGroupId,
115+
group_b: ReducedGroupId,
116+
) -> ReducedGroupId {
113117
if group_a == group_b {
114118
return group_a;
115119
}
116120
self.merged_groups
117121
.insert(group_a.as_group_id(), group_b.as_group_id());
122+
123+
// Copy all expressions from group a to group b
124+
let group_a_exprs = self.get_all_exprs_in_group(group_a.as_group_id());
125+
for expr_id in group_a_exprs {
126+
let expr_node = self.expr_id_to_expr_node.get(&expr_id).unwrap();
127+
self.add_expr_to_group(expr_id, group_b, expr_node.as_ref().clone());
128+
}
129+
130+
// Remove all expressions from group a (so we don't accidentally access it)
131+
self.clear_exprs_in_group(group_a);
132+
118133
group_b
119134
}
120135

136+
pub fn merge_group(&mut self, group_a: GroupId, group_b: GroupId) -> GroupId {
137+
let group_a_reduced = self.get_reduced_group_id(group_a);
138+
let group_b_reduced = self.get_reduced_group_id(group_b);
139+
self.merge_group_inner(group_a_reduced, group_b_reduced)
140+
.as_group_id()
141+
}
142+
121143
fn get_group_id_of_expr_id(&self, expr_id: ExprId) -> GroupId {
122144
self.expr_id_to_group_id[&expr_id]
123145
}
@@ -136,9 +158,11 @@ impl<T: RelNodeTyp> Memo<T> {
136158
rel_node: RelNodeRef<T>,
137159
add_to_group_id: Option<GroupId>,
138160
) -> (GroupId, ExprId) {
139-
if rel_node.typ.extract_group().is_some() {
140-
unreachable!();
141-
}
161+
let node_current_group = rel_node.typ.extract_group();
162+
if let (Some(grp_a), Some(grp_b)) = (add_to_group_id, node_current_group) {
163+
self.merge_group(grp_a, grp_b);
164+
};
165+
142166
let (group_id, expr_id) = self.add_new_group_expr_inner(
143167
rel_node,
144168
add_to_group_id.map(|x| self.get_reduced_group_id(x)),
@@ -198,6 +222,10 @@ impl<T: RelNodeTyp> Memo<T> {
198222
props
199223
}
200224

225+
fn clear_exprs_in_group(&mut self, group_id: ReducedGroupId) {
226+
self.groups.remove(&group_id);
227+
}
228+
201229
fn add_expr_to_group(
202230
&mut self,
203231
expr_id: ExprId,
@@ -243,7 +271,7 @@ impl<T: RelNodeTyp> Memo<T> {
243271
let group_id = self.get_group_id_of_expr_id(expr_id);
244272
let group_id = self.get_reduced_group_id(group_id);
245273
if let Some(add_to_group_id) = add_to_group_id {
246-
self.merge_group(add_to_group_id, group_id);
274+
self.merge_group_inner(add_to_group_id, group_id);
247275
}
248276
return (group_id, expr_id);
249277
}

optd-core/src/cascades/optimizer.rs

+4
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,10 @@ impl<T: RelNodeTyp> CascadesOptimizer<T> {
290290
self.memo.update_group_info(group_id, group_info)
291291
}
292292

293+
pub(super) fn merge_group(&mut self, group_a: GroupId, group_b: GroupId) {
294+
self.memo.merge_group(group_a, group_b);
295+
}
296+
293297
pub fn get_property_by_group<P: PropertyBuilder<T>>(
294298
&self,
295299
group_id: GroupId,

optd-core/src/cascades/tasks/apply_rule.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -197,8 +197,10 @@ impl<T: RelNodeTyp> Task<T> for ApplyRuleTask {
197197
let applied = rule.apply(optimizer, expr);
198198
for expr in applied {
199199
let RelNode { typ, .. } = &expr;
200-
if typ.extract_group().is_some() {
201-
unreachable!();
200+
if let Some(group_id_2) = typ.extract_group() {
201+
// If this is a group, merge the groups!
202+
optimizer.merge_group(group_id, group_id_2);
203+
continue;
202204
}
203205
let expr_typ = typ.clone();
204206
let (_, expr_id) = optimizer.add_group_expr(expr.into(), Some(group_id));

optd-datafusion-bridge/src/into_optd.rs

+4
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,10 @@ impl OptdPlanContext<'_> {
110110
let x = x.as_ref().unwrap();
111111
Ok(ConstantExpr::decimal(*x as f64).into_expr())
112112
}
113+
ScalarValue::Boolean(x) => {
114+
let x = x.as_ref().unwrap();
115+
Ok(ConstantExpr::bool(*x).into_expr())
116+
}
113117
_ => bail!("{:?}", x),
114118
},
115119
Expr::Alias(x) => self.conv_into_optd_expr(x.expr.as_ref(), context),

optd-datafusion-repr/src/lib.rs

+5-2
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ use optd_core::cascades::{CascadesOptimizer, GroupId, OptimizerProperties};
88
use plan_nodes::{OptRelNode, OptRelNodeRef, OptRelNodeTyp, PlanNode};
99
use properties::schema::{Catalog, SchemaPropertyBuilder};
1010
use rules::{
11-
EliminateJoinRule, HashJoinRule, JoinAssocRule, JoinCommuteRule, PhysicalConversionRule,
12-
ProjectionPullUpJoin,
11+
EliminateFilterRule, EliminateJoinRule, HashJoinRule, JoinAssocRule, JoinCommuteRule,
12+
PhysicalConversionRule, ProjectionPullUpJoin,
1313
};
1414

1515
pub use adaptive::PhysicalCollector;
@@ -48,6 +48,7 @@ impl DatafusionOptimizer {
4848
rules.push(Arc::new(JoinAssocRule::new()));
4949
rules.push(Arc::new(ProjectionPullUpJoin::new()));
5050
rules.push(Arc::new(EliminateJoinRule::new()));
51+
rules.push(Arc::new(EliminateFilterRule::new()));
5152

5253
let cost_model = AdaptiveCostModel::new(50);
5354
Self {
@@ -72,6 +73,8 @@ impl DatafusionOptimizer {
7273
rules.insert(0, Arc::new(JoinCommuteRule::new()));
7374
rules.insert(1, Arc::new(JoinAssocRule::new()));
7475
rules.insert(2, Arc::new(ProjectionPullUpJoin::new()));
76+
rules.insert(3, Arc::new(EliminateFilterRule::new()));
77+
7578
let cost_model = AdaptiveCostModel::new(1000); // very large decay
7679
let runtime_statistics = cost_model.get_runtime_map();
7780
let optimizer = CascadesOptimizer::new(

optd-datafusion-repr/src/rules.rs

+2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
// mod filter_join;
2+
mod eliminate_filter;
23
mod joins;
34
mod macros;
45
mod physical;
56

67
// pub use filter_join::FilterJoinPullUpRule;
8+
pub use eliminate_filter::EliminateFilterRule;
79
pub use joins::{
810
EliminateJoinRule, HashJoinRule, JoinAssocRule, JoinCommuteRule, ProjectionPullUpJoin,
911
};
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
use std::collections::HashMap;
2+
3+
use optd_core::rules::{Rule, RuleMatcher};
4+
use optd_core::{optimizer::Optimizer, rel_node::RelNode};
5+
6+
use crate::plan_nodes::{ConstantType, LogicalEmptyRelation, OptRelNode, OptRelNodeTyp};
7+
8+
use super::macros::define_rule;
9+
10+
define_rule!(
11+
EliminateFilterRule,
12+
apply_eliminate_filter,
13+
(Filter, child, [cond])
14+
);
15+
16+
/// Transformations:
17+
/// - Filter node w/ false pred -> EmptyRelation
18+
/// - Filter node w/ true pred -> Eliminate from the tree
19+
fn apply_eliminate_filter(
20+
_optimizer: &impl Optimizer<OptRelNodeTyp>,
21+
EliminateFilterRulePicks { child, cond }: EliminateFilterRulePicks,
22+
) -> Vec<RelNode<OptRelNodeTyp>> {
23+
if let OptRelNodeTyp::Constant(ConstantType::Bool) = cond.typ {
24+
if let Some(data) = cond.data {
25+
if data.as_bool() {
26+
// If the condition is true, eliminate the filter node, as it
27+
// will yield everything from below it.
28+
return vec![child];
29+
} else {
30+
// If the condition is false, replace this node with the empty relation,
31+
// since it will never yield tuples.
32+
let node = LogicalEmptyRelation::new(false);
33+
return vec![node.into_rel_node().as_ref().clone()];
34+
}
35+
}
36+
}
37+
vec![]
38+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
-- (no id or description)
2+
create table t1(t1v1 int, t1v2 int);
3+
create table t2(t2v1 int, t2v3 int);
4+
insert into t1 values (0, 0), (1, 1), (2, 2);
5+
insert into t2 values (0, 200), (1, 201), (2, 202);
6+
7+
/*
8+
3
9+
3
10+
*/
11+
12+
-- Test EliminateFilterRule (false filter to empty relation)
13+
select * from t1 where false;
14+
15+
/*
16+
LogicalProjection { exprs: [ #0, #1 ] }
17+
└── LogicalFilter { cond: false }
18+
└── LogicalScan { table: t1 }
19+
PhysicalProjection { exprs: [ #0, #1 ] }
20+
└── PhysicalEmptyRelation { produce_one_row: false }
21+
*/
22+
23+
-- Test EliminateFilterRule (replace true filter with child)
24+
select * from t1 where true;
25+
26+
/*
27+
LogicalProjection { exprs: [ #0, #1 ] }
28+
└── LogicalFilter { cond: true }
29+
└── LogicalScan { table: t1 }
30+
PhysicalProjection { exprs: [ #0, #1 ] }
31+
└── PhysicalScan { table: t1 }
32+
0 0
33+
1 1
34+
2 2
35+
*/
36+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
- sql: |
2+
create table t1(t1v1 int, t1v2 int);
3+
create table t2(t2v1 int, t2v3 int);
4+
insert into t1 values (0, 0), (1, 1), (2, 2);
5+
insert into t2 values (0, 200), (1, 201), (2, 202);
6+
tasks:
7+
- execute
8+
- sql: |
9+
select * from t1 where false;
10+
desc: Test EliminateFilterRule (false filter to empty relation)
11+
tasks:
12+
- explain:logical_optd,physical_optd
13+
- execute
14+
- sql: |
15+
select * from t1 where true;
16+
desc: Test EliminateFilterRule (replace true filter with child)
17+
tasks:
18+
- explain:logical_optd,physical_optd
19+
- execute

optd-sqlplannertest/tests/empty_relation.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@
1818
desc: Test whether the optimizer eliminates join to empty relation
1919
tasks:
2020
- explain:logical_optd,physical_optd
21-
- execute
21+
- execute

0 commit comments

Comments
 (0)