Skip to content
This repository was archived by the owner on Jan 7, 2025. It is now read-only.

Commit 4096073

Browse files
authored
feat(df-repr): add back join order enumeration (#204)
Ref https://github.com/cmu-db/optd/issues/194. After the memo table refactor, this commit adds back a more efficient join order enumeration implementation. --------- Signed-off-by: Alex Chi <[email protected]>
1 parent ae425ed commit 4096073

File tree

14 files changed

+241
-161
lines changed

14 files changed

+241
-161
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

optd-core/src/cascades.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@ mod memo;
44
mod optimizer;
55
mod tasks;
66

7-
use memo::Memo;
8-
pub use optimizer::{CascadesOptimizer, GroupId, OptimizerProperties, RelNodeContext};
7+
pub use memo::Memo;
8+
pub use optimizer::{CascadesOptimizer, ExprId, GroupId, OptimizerProperties, RelNodeContext};
99
use tasks::Task;

optd-core/src/cascades/memo.rs

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,20 @@ pub struct RelMemoNode<T: RelNodeTyp> {
2626
pub data: Option<Value>,
2727
}
2828

29+
impl<T: RelNodeTyp> RelMemoNode<T> {
30+
pub fn into_rel_node(self) -> RelNode<T> {
31+
RelNode {
32+
typ: self.typ,
33+
children: self
34+
.children
35+
.into_iter()
36+
.map(|x| Arc::new(RelNode::new_group(x)))
37+
.collect(),
38+
data: self.data,
39+
}
40+
}
41+
}
42+
2943
impl<T: RelNodeTyp> std::fmt::Display for RelMemoNode<T> {
3044
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
3145
write!(f, "({}", self.typ)?;
@@ -401,7 +415,7 @@ impl<T: RelNodeTyp> Memo<T> {
401415
}
402416

403417
/// Get the memoized representation of a node, only for debugging purpose
404-
pub(crate) fn get_expr_memoed(&self, mut expr_id: ExprId) -> RelMemoNodeRef<T> {
418+
pub fn get_expr_memoed(&self, mut expr_id: ExprId) -> RelMemoNodeRef<T> {
405419
while let Some(new_expr_id) = self.dup_expr_mapping.get(&expr_id) {
406420
expr_id = *new_expr_id;
407421
}
@@ -411,7 +425,8 @@ impl<T: RelNodeTyp> Memo<T> {
411425
.clone()
412426
}
413427

414-
pub(crate) fn get_all_exprs_in_group(&self, group_id: GroupId) -> Vec<ExprId> {
428+
pub fn get_all_exprs_in_group(&self, group_id: GroupId) -> Vec<ExprId> {
429+
let group_id = self.reduce_group(group_id);
415430
let group = self.groups.get(&group_id).expect("group not found");
416431
let mut exprs = group.group_exprs.iter().copied().collect_vec();
417432
exprs.sort();

optd-core/src/cascades/optimizer.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,10 @@ impl<T: RelNodeTyp> CascadesOptimizer<T> {
369369
.map(|x| x.cost.0[0])
370370
.unwrap_or(0.0)
371371
}
372+
373+
pub fn memo(&self) -> &Memo<T> {
374+
&self.memo
375+
}
372376
}
373377

374378
impl<T: RelNodeTyp> Optimizer<T> for CascadesOptimizer<T> {

optd-datafusion-bridge/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@ anyhow = "1"
1515
async-recursion = "1"
1616
futures-lite = "2"
1717
futures-util = "0.3"
18+
itertools = "0.11"

optd-datafusion-bridge/src/lib.rs

Lines changed: 14 additions & 138 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,11 @@ use datafusion::{
1616
physical_plan::{displayable, explain::ExplainExec, ExecutionPlan},
1717
physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner},
1818
};
19+
use itertools::Itertools;
1920
use optd_datafusion_repr::{
20-
plan_nodes::{
21-
ConstantType, OptRelNode, OptRelNodeRef, OptRelNodeTyp, PhysicalHashJoin,
22-
PhysicalNestedLoopJoin, PlanNode,
23-
},
21+
plan_nodes::{ConstantType, OptRelNode, PlanNode},
2422
properties::schema::Catalog,
25-
DatafusionOptimizer,
23+
DatafusionOptimizer, MemoExt,
2624
};
2725
use std::{
2826
collections::HashMap,
@@ -89,93 +87,6 @@ pub struct OptdQueryPlanner {
8987
pub optimizer: Arc<Mutex<Option<Box<DatafusionOptimizer>>>>,
9088
}
9189

92-
#[derive(Debug, Eq, PartialEq, Hash, PartialOrd, Ord)]
93-
enum JoinOrder {
94-
Table(String),
95-
HashJoin(Box<Self>, Box<Self>),
96-
NestedLoopJoin(Box<Self>, Box<Self>),
97-
}
98-
99-
#[allow(dead_code)]
100-
impl JoinOrder {
101-
pub fn conv_into_logical_join_order(&self) -> LogicalJoinOrder {
102-
match self {
103-
JoinOrder::Table(name) => LogicalJoinOrder::Table(name.clone()),
104-
JoinOrder::HashJoin(left, right) => LogicalJoinOrder::Join(
105-
Box::new(left.conv_into_logical_join_order()),
106-
Box::new(right.conv_into_logical_join_order()),
107-
),
108-
JoinOrder::NestedLoopJoin(left, right) => LogicalJoinOrder::Join(
109-
Box::new(left.conv_into_logical_join_order()),
110-
Box::new(right.conv_into_logical_join_order()),
111-
),
112-
}
113-
}
114-
}
115-
116-
#[allow(unused)]
117-
#[derive(Debug, Eq, PartialEq, Hash, PartialOrd, Ord)]
118-
enum LogicalJoinOrder {
119-
Table(String),
120-
Join(Box<Self>, Box<Self>),
121-
}
122-
123-
#[allow(dead_code)]
124-
fn get_join_order(rel_node: OptRelNodeRef) -> Option<JoinOrder> {
125-
match rel_node.typ {
126-
OptRelNodeTyp::PhysicalHashJoin(_) => {
127-
let join = PhysicalHashJoin::from_rel_node(rel_node.clone()).unwrap();
128-
let left = get_join_order(join.left().into_rel_node())?;
129-
let right = get_join_order(join.right().into_rel_node())?;
130-
Some(JoinOrder::HashJoin(Box::new(left), Box::new(right)))
131-
}
132-
OptRelNodeTyp::PhysicalNestedLoopJoin(_) => {
133-
let join = PhysicalNestedLoopJoin::from_rel_node(rel_node.clone()).unwrap();
134-
let left = get_join_order(join.left().into_rel_node())?;
135-
let right = get_join_order(join.right().into_rel_node())?;
136-
Some(JoinOrder::NestedLoopJoin(Box::new(left), Box::new(right)))
137-
}
138-
OptRelNodeTyp::PhysicalScan => {
139-
let scan =
140-
optd_datafusion_repr::plan_nodes::PhysicalScan::from_rel_node(rel_node).unwrap();
141-
Some(JoinOrder::Table(scan.table().to_string()))
142-
}
143-
_ => {
144-
for child in &rel_node.children {
145-
if let Some(res) = get_join_order(child.clone()) {
146-
return Some(res);
147-
}
148-
}
149-
None
150-
}
151-
}
152-
}
153-
154-
impl std::fmt::Display for LogicalJoinOrder {
155-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
156-
match self {
157-
LogicalJoinOrder::Table(name) => write!(f, "{}", name),
158-
LogicalJoinOrder::Join(left, right) => {
159-
write!(f, "(Join {} {})", left, right)
160-
}
161-
}
162-
}
163-
}
164-
165-
impl std::fmt::Display for JoinOrder {
166-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167-
match self {
168-
JoinOrder::Table(name) => write!(f, "{}", name),
169-
JoinOrder::HashJoin(left, right) => {
170-
write!(f, "(HashJoin {} {})", left, right)
171-
}
172-
JoinOrder::NestedLoopJoin(left, right) => {
173-
write!(f, "(NLJ {} {})", left, right)
174-
}
175-
}
176-
}
177-
}
178-
17990
impl OptdQueryPlanner {
18091
pub fn enable_adaptive(&self) {
18192
self.optimizer
@@ -247,7 +158,7 @@ impl OptdQueryPlanner {
247158
}
248159
}
249160

250-
let (_, optimized_rel, meta) = optimizer.cascades_optimize(optd_rel)?;
161+
let (group_id, optimized_rel, meta) = optimizer.cascades_optimize(optd_rel)?;
251162

252163
if let Some(explains) = &mut explains {
253164
explains.push(StringifiedPlan::new(
@@ -258,52 +169,17 @@ impl OptdQueryPlanner {
258169
.unwrap()
259170
.explain_to_string(if verbose { Some(&meta) } else { None }),
260171
));
261-
262-
// const ENABLE_JOIN_ORDER: bool = false;
263-
264-
// if ENABLE_JOIN_ORDER {
265-
// let join_order = get_join_order(optimized_rel.clone());
266-
// explains.push(StringifiedPlan::new(
267-
// PlanType::OptimizedPhysicalPlan {
268-
// optimizer_name: "optd-join-order".to_string(),
269-
// },
270-
// if let Some(join_order) = join_order {
271-
// join_order.to_string()
272-
// } else {
273-
// "None".to_string()
274-
// },
275-
// ));
276-
// let bindings = optimizer
277-
// .optd_cascades_optimizer()
278-
// .get_all_group_bindings(group_id, true);
279-
// let mut join_orders = BTreeSet::new();
280-
// let mut logical_join_orders = BTreeSet::new();
281-
// for binding in bindings {
282-
// if let Some(join_order) = get_join_order(binding) {
283-
// logical_join_orders.insert(join_order.conv_into_logical_join_order());
284-
// join_orders.insert(join_order);
285-
// }
286-
// }
287-
// explains.push(StringifiedPlan::new(
288-
// PlanType::OptimizedPhysicalPlan {
289-
// optimizer_name: "optd-all-join-orders".to_string(),
290-
// },
291-
// join_orders.iter().map(|x| x.to_string()).join("\n"),
292-
// ));
293-
// explains.push(StringifiedPlan::new(
294-
// PlanType::OptimizedPhysicalPlan {
295-
// optimizer_name: "optd-all-logical-join-orders".to_string(),
296-
// },
297-
// logical_join_orders.iter().map(|x| x.to_string()).join("\n"),
298-
// ));
299-
// }
172+
let join_orders = optimizer
173+
.optd_cascades_optimizer()
174+
.memo()
175+
.enumerate_join_order(group_id);
176+
explains.push(StringifiedPlan::new(
177+
PlanType::OptimizedPhysicalPlan {
178+
optimizer_name: "optd-all-logical-join-orders".to_string(),
179+
},
180+
join_orders.iter().map(|x| x.to_string()).join("\n"),
181+
));
300182
}
301-
// println!(
302-
// "{} cost={}",
303-
// get_join_order(optimized_rel.clone()).unwrap(),
304-
// optimizer.optd_optimizer().get_cost_of(group_id)
305-
// );
306-
// optimizer.dump(Some(group_id));
307183
ctx.optimizer = Some(&optimizer);
308184
let physical_plan = ctx.conv_from_optd(optimized_rel, meta).await?;
309185
if let Some(explains) = &mut explains {

optd-datafusion-repr/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,16 @@ use crate::rules::{
3434
DepInitialDistinct, DepJoinEliminateAtScan, DepJoinPastAgg, DepJoinPastFilter, DepJoinPastProj,
3535
};
3636

37+
pub use memo_ext::{LogicalJoinOrder, MemoExt};
38+
3739
pub mod cost;
3840
mod explain;
41+
mod memo_ext;
3942
pub mod plan_nodes;
4043
pub mod properties;
4144
pub mod rules;
4245
#[cfg(test)]
4346
mod testing;
44-
// mod expand;
4547

4648
pub struct DatafusionOptimizer {
4749
heuristic_optimizer: HeuristicsOptimizer<OptRelNodeTyp>,

0 commit comments

Comments
 (0)