Skip to content
This repository was archived by the owner on Jan 7, 2025. It is now read-only.

Commit 7b530d1

Browse files
authored
refactor(core): rewrite memo table + enable filter-project rules in df-repr (#203)
* rewrite the memo table to fully resolve the group merging issues; add test cases; drop a few stale APIs from the memo table * enable the filter-project set of rules and it won't run out of budget now --------- Signed-off-by: Alex Chi <[email protected]>
1 parent 1000e13 commit 7b530d1

File tree

14 files changed

+700
-494
lines changed

14 files changed

+700
-494
lines changed

optd-core/src/cascades/memo.rs

+471-294
Large diffs are not rendered by default.

optd-core/src/cascades/optimizer.rs

+30-63
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ pub struct OptimizerContext {
3232

3333
#[derive(Default, Clone, Debug)]
3434
pub struct OptimizerProperties {
35-
pub partial_explore_temporarily_disabled: bool,
35+
pub panic_on_budget: bool,
3636
/// If the number of rules applied exceeds this number, we stop applying logical rules.
3737
pub partial_explore_iter: Option<usize>,
3838
/// Plan space can be expanded by this number of times before we stop applying logical rules.
@@ -88,12 +88,8 @@ impl<T: RelNodeTyp> CascadesOptimizer<T> {
8888
Self::new_with_prop(rules, cost, property_builders, Default::default())
8989
}
9090

91-
pub fn disable_explore_limit(&mut self) {
92-
self.prop.partial_explore_temporarily_disabled = true;
93-
}
94-
95-
pub fn enable_explore_limit(&mut self) {
96-
self.prop.partial_explore_temporarily_disabled = false;
91+
pub fn panic_on_explore_limit(&mut self, enabled: bool) {
92+
self.prop.panic_on_budget = enabled;
9793
}
9894

9995
pub fn new_with_prop(
@@ -190,12 +186,13 @@ impl<T: RelNodeTyp> CascadesOptimizer<T> {
190186
for expr_id in self.memo.get_all_exprs_in_group(group_id) {
191187
let memo_node = self.memo.get_expr_memoed(expr_id);
192188
println!(" expr_id={} | {}", expr_id, memo_node);
193-
let bindings = self
194-
.memo
195-
.get_all_expr_bindings(expr_id, false, true, Some(1));
196-
for binding in bindings {
197-
println!(" {}", binding);
198-
}
189+
// We removed get all bindings functionality
190+
// let bindings = self
191+
// .memo
192+
// .get_all_expr_bindings(expr_id, false, true, Some(1));
193+
// for binding in bindings {
194+
// println!(" {}", binding);
195+
// }
199196
}
200197
}
201198
}
@@ -214,7 +211,7 @@ impl<T: RelNodeTyp> CascadesOptimizer<T> {
214211

215212
/// Optimize a `RelNode`.
216213
pub fn step_optimize_rel(&mut self, root_rel: RelNodeRef<T>) -> Result<GroupId> {
217-
let (group_id, _) = self.add_group_expr(root_rel, None);
214+
let (group_id, _) = self.add_new_expr(root_rel);
218215
self.fire_optimize_tasks(group_id)?;
219216
Ok(group_id)
220217
}
@@ -240,7 +237,7 @@ impl<T: RelNodeTyp> CascadesOptimizer<T> {
240237
let new_tasks = task.execute(self)?;
241238
self.tasks.extend(new_tasks);
242239
iter += 1;
243-
if !self.ctx.budget_used && !self.prop.partial_explore_temporarily_disabled {
240+
if !self.ctx.budget_used {
244241
let plan_space = self.memo.compute_plan_space();
245242
if let Some(partial_explore_space) = self.prop.partial_explore_space {
246243
if plan_space - plan_space_begin > partial_explore_space {
@@ -249,6 +246,9 @@ impl<T: RelNodeTyp> CascadesOptimizer<T> {
249246
plan_space
250247
);
251248
self.ctx.budget_used = true;
249+
if self.prop.panic_on_budget {
250+
panic!("plan space size budget used");
251+
}
252252
}
253253
} else if let Some(partial_explore_iter) = self.prop.partial_explore_iter {
254254
if iter >= partial_explore_iter {
@@ -257,15 +257,21 @@ impl<T: RelNodeTyp> CascadesOptimizer<T> {
257257
plan_space
258258
);
259259
self.ctx.budget_used = true;
260+
if self.prop.panic_on_budget {
261+
panic!("plan space size budget used");
262+
}
260263
}
261264
}
262265
}
263266
}
267+
// if self.ctx.budget_used {
268+
// self.dump(None);
269+
// }
264270
Ok(())
265271
}
266272

267273
fn optimize_inner(&mut self, root_rel: RelNodeRef<T>) -> Result<RelNodeRef<T>> {
268-
let (group_id, _) = self.add_group_expr(root_rel, None);
274+
let (group_id, _) = self.add_new_expr(root_rel);
269275
self.fire_optimize_tasks(group_id)?;
270276
self.memo.get_best_group_binding(group_id, &mut None)
271277
}
@@ -286,37 +292,16 @@ impl<T: RelNodeTyp> CascadesOptimizer<T> {
286292
self.memo.get_expr_info(expr)
287293
}
288294

289-
pub(super) fn add_group_expr(
290-
&mut self,
291-
expr: RelNodeRef<T>,
292-
group_id: Option<GroupId>,
293-
) -> (GroupId, ExprId) {
294-
self.memo.add_new_group_expr(expr, group_id)
295+
pub fn add_new_expr(&mut self, rel_node: RelNodeRef<T>) -> (GroupId, ExprId) {
296+
self.memo.add_new_expr(rel_node)
295297
}
296298

297-
#[allow(dead_code)]
298-
pub(super) fn replace_group_expr(
299+
pub fn add_expr_to_group(
299300
&mut self,
300-
expr: RelNodeRef<T>,
301+
rel_node: RelNodeRef<T>,
301302
group_id: GroupId,
302-
expr_id: ExprId,
303-
) {
304-
let replaced = self.memo.replace_group_expr(expr_id, group_id, expr);
305-
if replaced {
306-
// the old expr is replaced, so we clear the fired rules for old expr
307-
self.fired_rules.entry(expr_id).or_default().clear();
308-
return;
309-
}
310-
311-
// We can mark the expr as a deadend
312-
// However, even some of the exprs cannot be the winner for the group
313-
// We still need the physical form of those expr to start the optimizeInput task
314-
// So we don't mark the impl rules as fired
315-
for i in 0..self.rules.len() {
316-
if !self.rules[i].rule().is_impl_rule() {
317-
self.fired_rules.entry(expr_id).or_default().insert(i);
318-
}
319-
}
303+
) -> Option<ExprId> {
304+
self.memo.add_expr_to_group(rel_node, group_id)
320305
}
321306

322307
pub(super) fn get_group_info(&self, group_id: GroupId) -> GroupInfo {
@@ -327,10 +312,6 @@ impl<T: RelNodeTyp> CascadesOptimizer<T> {
327312
self.memo.update_group_info(group_id, group_info)
328313
}
329314

330-
pub(super) fn merge_group(&mut self, group_a: GroupId, group_b: GroupId) {
331-
self.memo.merge_group(group_a, group_b);
332-
}
333-
334315
/// Get the properties of a Cascades group
335316
/// P is the type of the property you expect
336317
/// idx is the idx of the property you want. The order of properties is defined
@@ -354,22 +335,8 @@ impl<T: RelNodeTyp> CascadesOptimizer<T> {
354335
self.memo.get_expr_memoed(expr_id)
355336
}
356337

357-
pub(super) fn get_all_expr_bindings(
358-
&self,
359-
expr_id: ExprId,
360-
level: Option<usize>,
361-
) -> Vec<RelNodeRef<T>> {
362-
self.memo
363-
.get_all_expr_bindings(expr_id, false, false, level)
364-
}
365-
366-
pub fn get_all_group_bindings(
367-
&self,
368-
group_id: GroupId,
369-
physical_only: bool,
370-
) -> Vec<RelNodeRef<T>> {
371-
self.memo
372-
.get_all_group_bindings(group_id, physical_only, true, Some(10))
338+
pub fn get_predicate_binding(&self, group_id: GroupId) -> Option<RelNodeRef<T>> {
339+
self.memo.get_predicate_binding(group_id)
373340
}
374341

375342
pub(super) fn is_group_explored(&self, group_id: GroupId) -> bool {

optd-core/src/cascades/tasks/apply_rule.rs

+18-22
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,10 @@ fn match_node<T: RelNodeTyp>(
6262
RuleMatcher::PickOne { pick_to, expand } => {
6363
let group_id = node.children[idx];
6464
let node = if *expand {
65-
let mut exprs = optimizer.get_all_exprs_in_group(group_id);
66-
assert_eq!(exprs.len(), 1, "can only expand expression");
67-
let expr = exprs.remove(0);
68-
let mut bindings = optimizer.get_all_expr_bindings(expr, None);
69-
assert_eq!(bindings.len(), 1, "can only expand expression");
70-
bindings.remove(0).as_ref().clone()
65+
let binding = optimizer
66+
.get_predicate_binding(group_id)
67+
.expect("empty group, what's going wrong?");
68+
binding.as_ref().clone()
7169
} else {
7270
RelNode::new_group(group_id)
7371
};
@@ -205,23 +203,21 @@ impl<T: RelNodeTyp> Task<T> for ApplyRuleTask {
205203

206204
for expr in applied {
207205
trace!(event = "after_apply_rule", task = "apply_rule", binding=%expr);
208-
let RelNode { typ, .. } = &expr;
209-
if let Some(group_id_2) = typ.extract_group() {
210-
// If this is a group, merge the groups!
211-
optimizer.merge_group(group_id, group_id_2);
212-
continue;
213-
}
214-
let expr_typ = typ.clone();
215-
let (_, expr_id) = optimizer.add_group_expr(expr.into(), Some(group_id));
216-
trace!(event = "apply_rule", expr_id = %self.expr_id, rule_id = %self.rule_id, new_expr_id = %expr_id);
217-
if expr_typ.is_logical() {
218-
tasks.push(
219-
Box::new(OptimizeExpressionTask::new(expr_id, self.exploring))
220-
as Box<dyn Task<T>>,
221-
);
206+
let expr_typ = expr.typ.clone();
207+
if let Some(expr_id) = optimizer.add_expr_to_group(expr.into(), group_id) {
208+
if expr_typ.is_logical() {
209+
tasks.push(
210+
Box::new(OptimizeExpressionTask::new(expr_id, self.exploring))
211+
as Box<dyn Task<T>>,
212+
);
213+
} else {
214+
tasks
215+
.push(Box::new(OptimizeInputsTask::new(expr_id, true))
216+
as Box<dyn Task<T>>);
217+
}
218+
trace!(event = "apply_rule", expr_id = %self.expr_id, rule_id = %self.rule_id, new_expr_id = %expr_id);
222219
} else {
223-
tasks
224-
.push(Box::new(OptimizeInputsTask::new(expr_id, true)) as Box<dyn Task<T>>);
220+
trace!(event = "apply_rule", expr_id = %self.expr_id, rule_id = %self.rule_id, "triggered group merge");
225221
}
226222
}
227223
}

optd-datafusion-bridge/src/lib.rs

+42-40
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ use datafusion::{
1616
physical_plan::{displayable, explain::ExplainExec, ExecutionPlan},
1717
physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner},
1818
};
19-
use itertools::Itertools;
2019
use optd_datafusion_repr::{
2120
plan_nodes::{
2221
ConstantType, OptRelNode, OptRelNodeRef, OptRelNodeTyp, PhysicalHashJoin,
@@ -26,7 +25,7 @@ use optd_datafusion_repr::{
2625
DatafusionOptimizer,
2726
};
2827
use std::{
29-
collections::{BTreeSet, HashMap},
28+
collections::HashMap,
3029
sync::{Arc, Mutex},
3130
};
3231

@@ -97,6 +96,7 @@ enum JoinOrder {
9796
NestedLoopJoin(Box<Self>, Box<Self>),
9897
}
9998

99+
#[allow(dead_code)]
100100
impl JoinOrder {
101101
pub fn conv_into_logical_join_order(&self) -> LogicalJoinOrder {
102102
match self {
@@ -113,12 +113,14 @@ impl JoinOrder {
113113
}
114114
}
115115

116+
#[allow(unused)]
116117
#[derive(Debug, Eq, PartialEq, Hash, PartialOrd, Ord)]
117118
enum LogicalJoinOrder {
118119
Table(String),
119120
Join(Box<Self>, Box<Self>),
120121
}
121122

123+
#[allow(dead_code)]
122124
fn get_join_order(rel_node: OptRelNodeRef) -> Option<JoinOrder> {
123125
match rel_node.typ {
124126
OptRelNodeTyp::PhysicalHashJoin(_) => {
@@ -245,7 +247,7 @@ impl OptdQueryPlanner {
245247
}
246248
}
247249

248-
let (group_id, optimized_rel, meta) = optimizer.cascades_optimize(optd_rel)?;
250+
let (_, optimized_rel, meta) = optimizer.cascades_optimize(optd_rel)?;
249251

250252
if let Some(explains) = &mut explains {
251253
explains.push(StringifiedPlan::new(
@@ -257,44 +259,44 @@ impl OptdQueryPlanner {
257259
.explain_to_string(if verbose { Some(&meta) } else { None }),
258260
));
259261

260-
const ENABLE_JOIN_ORDER: bool = false;
262+
// const ENABLE_JOIN_ORDER: bool = false;
261263

262-
if ENABLE_JOIN_ORDER {
263-
let join_order = get_join_order(optimized_rel.clone());
264-
explains.push(StringifiedPlan::new(
265-
PlanType::OptimizedPhysicalPlan {
266-
optimizer_name: "optd-join-order".to_string(),
267-
},
268-
if let Some(join_order) = join_order {
269-
join_order.to_string()
270-
} else {
271-
"None".to_string()
272-
},
273-
));
274-
let bindings = optimizer
275-
.optd_cascades_optimizer()
276-
.get_all_group_bindings(group_id, true);
277-
let mut join_orders = BTreeSet::new();
278-
let mut logical_join_orders = BTreeSet::new();
279-
for binding in bindings {
280-
if let Some(join_order) = get_join_order(binding) {
281-
logical_join_orders.insert(join_order.conv_into_logical_join_order());
282-
join_orders.insert(join_order);
283-
}
284-
}
285-
explains.push(StringifiedPlan::new(
286-
PlanType::OptimizedPhysicalPlan {
287-
optimizer_name: "optd-all-join-orders".to_string(),
288-
},
289-
join_orders.iter().map(|x| x.to_string()).join("\n"),
290-
));
291-
explains.push(StringifiedPlan::new(
292-
PlanType::OptimizedPhysicalPlan {
293-
optimizer_name: "optd-all-logical-join-orders".to_string(),
294-
},
295-
logical_join_orders.iter().map(|x| x.to_string()).join("\n"),
296-
));
297-
}
264+
// if ENABLE_JOIN_ORDER {
265+
// let join_order = get_join_order(optimized_rel.clone());
266+
// explains.push(StringifiedPlan::new(
267+
// PlanType::OptimizedPhysicalPlan {
268+
// optimizer_name: "optd-join-order".to_string(),
269+
// },
270+
// if let Some(join_order) = join_order {
271+
// join_order.to_string()
272+
// } else {
273+
// "None".to_string()
274+
// },
275+
// ));
276+
// let bindings = optimizer
277+
// .optd_cascades_optimizer()
278+
// .get_all_group_bindings(group_id, true);
279+
// let mut join_orders = BTreeSet::new();
280+
// let mut logical_join_orders = BTreeSet::new();
281+
// for binding in bindings {
282+
// if let Some(join_order) = get_join_order(binding) {
283+
// logical_join_orders.insert(join_order.conv_into_logical_join_order());
284+
// join_orders.insert(join_order);
285+
// }
286+
// }
287+
// explains.push(StringifiedPlan::new(
288+
// PlanType::OptimizedPhysicalPlan {
289+
// optimizer_name: "optd-all-join-orders".to_string(),
290+
// },
291+
// join_orders.iter().map(|x| x.to_string()).join("\n"),
292+
// ));
293+
// explains.push(StringifiedPlan::new(
294+
// PlanType::OptimizedPhysicalPlan {
295+
// optimizer_name: "optd-all-logical-join-orders".to_string(),
296+
// },
297+
// logical_join_orders.iter().map(|x| x.to_string()).join("\n"),
298+
// ));
299+
// }
298300
}
299301
// println!(
300302
// "{} cost={}",

optd-datafusion-repr/src/cost/base_cost/agg.rs

+3-10
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
1-
use std::sync::Arc;
2-
31
use optd_core::{
42
cascades::{CascadesOptimizer, RelNodeContext},
53
cost::Cost,
6-
rel_node::RelNode,
74
};
85
use serde::{de::DeserializeOwned, Serialize};
96

@@ -48,13 +45,9 @@ impl<
4845
) -> f64 {
4946
if let (Some(context), Some(optimizer)) = (context, optimizer) {
5047
let group_by_id = context.children_group_ids[2];
51-
let mut group_by_exprs: Vec<Arc<RelNode<OptRelNodeTyp>>> =
52-
optimizer.get_all_group_bindings(group_by_id, false);
53-
assert!(
54-
group_by_exprs.len() == 1,
55-
"ExprList expression should be the only expression in the GROUP BY group"
56-
);
57-
let group_by = group_by_exprs.pop().unwrap();
48+
let group_by = optimizer
49+
.get_predicate_binding(group_by_id)
50+
.expect("no expression found?");
5851
let group_by = ExprList::from_rel_node(group_by).unwrap();
5952
if group_by.is_empty() {
6053
1.0

0 commit comments

Comments
 (0)