Skip to content
This repository was archived by the owner on Jan 7, 2025. It is now read-only.

Commit 3140324

Browse files
committed
prune based on upper bound
Signed-off-by: Alex Chi Z <[email protected]>
1 parent 81a2e80 commit 3140324

File tree

7 files changed

+63
-21
lines changed

7 files changed

+63
-21
lines changed

optd-core/src/cascades/memo.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ impl<T: NodeType> NaiveMemo<T> {
352352
}
353353

354354
fn verify_integrity(&self) {
355-
if cfg!(debug_assertions) {
355+
if false {
356356
let num_of_exprs = self.expr_id_to_expr_node.len();
357357
assert_eq!(num_of_exprs, self.expr_node_to_expr_id.len());
358358
assert_eq!(num_of_exprs, self.expr_id_to_group_id.len());

optd-core/src/cascades/optimizer.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ impl<T: NodeType, M: Memo<T>> CascadesOptimizer<T, M> {
248248
fn fire_optimize_tasks(&mut self, group_id: GroupId) -> Result<()> {
249249
trace!(event = "fire_optimize_tasks", root_group_id = %group_id);
250250
self.tasks
251-
.push_back(Box::new(OptimizeGroupTask::new(group_id)));
251+
.push_back(Box::new(OptimizeGroupTask::new(group_id, None)));
252252
// get the task from the stack
253253
self.ctx.budget_used = false;
254254
let plan_space_begin = self.memo.estimated_plan_space();

optd-core/src/cascades/tasks/apply_rule.rs

+10-2
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,21 @@ pub struct ApplyRuleTask {
2121
rule_id: RuleId,
2222
expr_id: ExprId,
2323
exploring: bool,
24+
upper_bound: Option<f64>,
2425
}
2526

2627
impl ApplyRuleTask {
27-
pub fn new(rule_id: RuleId, expr_id: ExprId, exploring: bool) -> Self {
28+
pub fn new(
29+
rule_id: RuleId,
30+
expr_id: ExprId,
31+
exploring: bool,
32+
upper_bound: Option<f64>,
33+
) -> Self {
2834
Self {
2935
rule_id,
3036
expr_id,
3137
exploring,
38+
upper_bound,
3239
}
3340
}
3441
}
@@ -181,13 +188,14 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for ApplyRuleTask {
181188
let typ = expr.unwrap_typ();
182189
if typ.is_logical() {
183190
tasks.push(
184-
Box::new(OptimizeExpressionTask::new(expr_id, self.exploring))
191+
Box::new(OptimizeExpressionTask::new(expr_id, self.exploring, self.upper_bound))
185192
as Box<dyn Task<T, M>>,
186193
);
187194
} else {
188195
tasks.push(Box::new(OptimizeInputsTask::new(
189196
expr_id,
190197
!optimizer.prop.disable_pruning,
198+
self.upper_bound
191199
)) as Box<dyn Task<T, M>>);
192200
}
193201
optimizer.unmark_expr_explored(expr_id);

optd-core/src/cascades/tasks/explore_group.rs

+7-3
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,15 @@ use crate::nodes::NodeType;
1414

1515
pub struct ExploreGroupTask {
1616
group_id: GroupId,
17+
upper_bound: Option<f64>,
1718
}
1819

1920
impl ExploreGroupTask {
20-
pub fn new(group_id: GroupId) -> Self {
21-
Self { group_id }
21+
pub fn new(group_id: GroupId, upper_bound: Option<f64>) -> Self {
22+
Self {
23+
group_id,
24+
upper_bound,
25+
}
2226
}
2327
}
2428

@@ -36,7 +40,7 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for ExploreGroupTask {
3640
let typ = optimizer.get_expr_memoed(expr).typ.clone();
3741
if typ.is_logical() {
3842
tasks
39-
.push(Box::new(OptimizeExpressionTask::new(expr, true)) as Box<dyn Task<T, M>>);
43+
.push(Box::new(OptimizeExpressionTask::new(expr, true, self.upper_bound)) as Box<dyn Task<T, M>>);
4044
}
4145
}
4246
optimizer.mark_group_explored(self.group_id);

optd-core/src/cascades/tasks/optimize_expression.rs

+9-4
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,16 @@ use crate::rules::RuleMatcher;
1616
pub struct OptimizeExpressionTask {
1717
expr_id: ExprId,
1818
exploring: bool,
19+
upper_bound: Option<f64>,
1920
}
2021

2122
impl OptimizeExpressionTask {
22-
pub fn new(expr_id: ExprId, exploring: bool) -> Self {
23-
Self { expr_id, exploring }
23+
pub fn new(expr_id: ExprId, exploring: bool, upper_bound: Option<f64>) -> Self {
24+
Self {
25+
expr_id,
26+
exploring,
27+
upper_bound,
28+
}
2429
}
2530
}
2631

@@ -53,12 +58,12 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeExpressionTask {
5358
}
5459
if top_matches(rule.matcher(), expr.typ.clone()) {
5560
tasks.push(
56-
Box::new(ApplyRuleTask::new(rule_id, self.expr_id, self.exploring))
61+
Box::new(ApplyRuleTask::new(rule_id, self.expr_id, self.exploring, self.upper_bound))
5762
as Box<dyn Task<T, M>>,
5863
);
5964
for &input_group_id in &expr.children {
6065
tasks.push(
61-
Box::new(ExploreGroupTask::new(input_group_id)) as Box<dyn Task<T, M>>
66+
Box::new(ExploreGroupTask::new(input_group_id, self.upper_bound)) as Box<dyn Task<T, M>>
6267
);
6368
}
6469
}

optd-core/src/cascades/tasks/optimize_group.rs

+8-3
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,15 @@ use crate::nodes::NodeType;
1515

1616
pub struct OptimizeGroupTask {
1717
group_id: GroupId,
18+
upper_bound: Option<f64>,
1819
}
1920

2021
impl OptimizeGroupTask {
21-
pub fn new(group_id: GroupId) -> Self {
22-
Self { group_id }
22+
pub fn new(group_id: GroupId, upper_bound: Option<f64>) -> Self {
23+
Self {
24+
group_id,
25+
upper_bound,
26+
}
2327
}
2428
}
2529

@@ -37,7 +41,7 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeGroupTask {
3741
for &expr in &exprs {
3842
let typ = optimizer.get_expr_memoed(expr).typ.clone();
3943
if typ.is_logical() {
40-
tasks.push(Box::new(OptimizeExpressionTask::new(expr, false)) as Box<dyn Task<T, M>>);
44+
tasks.push(Box::new(OptimizeExpressionTask::new(expr, false, self.upper_bound)) as Box<dyn Task<T, M>>);
4145
}
4246
}
4347
for &expr in &exprs {
@@ -46,6 +50,7 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeGroupTask {
4650
tasks.push(Box::new(OptimizeInputsTask::new(
4751
expr,
4852
!optimizer.prop.disable_pruning,
53+
self.upper_bound
4954
)) as Box<dyn Task<T, M>>);
5055
}
5156
}

optd-core/src/cascades/tasks/optimize_inputs.rs

+27-7
Original file line numberDiff line numberDiff line change
@@ -42,22 +42,25 @@ pub struct OptimizeInputsTask {
4242
expr_id: ExprId,
4343
continue_from: Option<ContinueTask>,
4444
pruning: bool,
45+
upper_bound: Option<f64>,
4546
}
4647

4748
impl OptimizeInputsTask {
48-
pub fn new(expr_id: ExprId, pruning: bool) -> Self {
49+
pub fn new(expr_id: ExprId, pruning: bool, upper_bound: Option<f64>) -> Self {
4950
Self {
5051
expr_id,
5152
continue_from: None,
5253
pruning,
54+
upper_bound,
5355
}
5456
}
5557

56-
fn continue_from(&self, cont: ContinueTask, pruning: bool) -> Self {
58+
fn continue_from(&self, cont: ContinueTask, pruning: bool, upper_bound: Option<f64>) -> Self {
5759
Self {
5860
expr_id: self.expr_id,
5961
continue_from: Some(cont),
6062
pruning,
63+
upper_bound,
6164
}
6265
}
6366

@@ -153,6 +156,19 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeInputsTask {
153156

154157
trace!(event = "task_begin", task = "optimize_inputs", expr_id = %self.expr_id, continue_from = %ContinueTaskDisplay(&self.continue_from), total_children = %children_group_ids.len());
155158

159+
let upper_bound = if self.pruning {
160+
if let Some(upper_bound) = self.upper_bound {
161+
Some(upper_bound)
162+
} else if let Some(winner) = optimizer.get_group_info(group_id).winner.as_full_winner()
163+
{
164+
Some(winner.total_weighted_cost)
165+
} else {
166+
None
167+
}
168+
} else {
169+
None
170+
};
171+
156172
if let Some(ContinueTask {
157173
next_group_idx,
158174
return_from_optimize_group,
@@ -219,9 +235,9 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeInputsTask {
219235
winner_weighted_cost = %trace_fmt(&group_info.winner),
220236
current_processing = %next_group_idx,
221237
total_child_groups = %children_group_ids.len());
222-
if let Some(winner) = group_info.winner.as_full_winner() {
238+
if let Some(upper_bound) = upper_bound {
223239
let cost_so_far = cost.weighted_cost(&total_cost);
224-
if winner.total_weighted_cost <= cost_so_far {
240+
if upper_bound <= cost_so_far {
225241
trace!(event = "task_finish", task = "optimize_inputs", expr_id = %self.expr_id, result = "pruned");
226242
return Ok(vec![]);
227243
}
@@ -232,7 +248,7 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeInputsTask {
232248
let child_group_id = children_group_ids[next_group_idx];
233249
let group_idx = next_group_idx;
234250
let child_group_info = optimizer.get_group_info(child_group_id);
235-
if !child_group_info.winner.has_full_winner() {
251+
let Some(child_winner) = child_group_info.winner.as_full_winner() else {
236252
if !return_from_optimize_group {
237253
trace!(event = "task_yield", task = "optimize_inputs", expr_id = %self.expr_id, group_idx = %group_idx, yield_to = "optimize_group", optimize_group_id = %child_group_id);
238254
return Ok(vec![
@@ -242,22 +258,25 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeInputsTask {
242258
return_from_optimize_group: true,
243259
},
244260
self.pruning,
261+
upper_bound,
245262
)) as Box<dyn Task<T, M>>,
246-
Box::new(OptimizeGroupTask::new(child_group_id)) as Box<dyn Task<T, M>>,
263+
Box::new(OptimizeGroupTask::new(child_group_id, upper_bound))
264+
as Box<dyn Task<T, M>>,
247265
]);
248266
} else {
249267
self.update_winner_impossible(optimizer);
250268
trace!(event = "task_finish", task = "optimize_inputs", expr_id = %self.expr_id, result = "impossible");
251269
return Ok(vec![]);
252270
}
253-
}
271+
};
254272
trace!(event = "task_yield", task = "optimize_inputs", expr_id = %self.expr_id, group_idx = %group_idx, yield_to = "next_optimize_input");
255273
Ok(vec![Box::new(self.continue_from(
256274
ContinueTask {
257275
next_group_idx: group_idx + 1,
258276
return_from_optimize_group: false,
259277
},
260278
self.pruning,
279+
upper_bound.map(|bound| bound - child_winner.total_weighted_cost),
261280
)) as Box<dyn Task<T, M>>])
262281
} else {
263282
self.update_winner(input_statistics_ref, operation_cost, total_cost, optimizer);
@@ -272,6 +291,7 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeInputsTask {
272291
return_from_optimize_group: false,
273292
},
274293
self.pruning,
294+
upper_bound,
275295
)) as Box<dyn Task<T, M>>])
276296
}
277297
}

0 commit comments

Comments
 (0)