Skip to content
This repository was archived by the owner on Jan 7, 2025. It is now read-only.

Commit 2dd2a31

Browse files
authored
refactor(core): rm option around cost models (#243)
`Option` was introduced during a transitional period when we thought the cost model could compute the cost solely from the children's costs, but it turned out that we need the optimizer for a few derived logical properties. We can drop the `Option` wrappers now. Signed-off-by: Alex Chi <[email protected]>
1 parent f4b62f3 commit 2dd2a31

File tree

7 files changed

+51
-69
lines changed

7 files changed

+51
-69
lines changed

docs/src/cost_model.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ pub trait CostModel<T: RelNodeTyp>: 'static + Send + Sync {
1111
node: &T,
1212
data: &Option<Value>,
1313
children: &[Cost],
14-
context: Option<RelNodeContext>,
14+
context: RelNodeContext,
1515
) -> Cost;
1616
}
1717
```

optd-core/src/cascades/tasks/optimize_inputs.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -110,12 +110,12 @@ impl OptimizeInputsTask {
110110
.iter()
111111
.map(|x| x.expect("child winner should always have statistics?"))
112112
.collect::<Vec<_>>(),
113-
Some(RelNodeContext {
113+
RelNodeContext {
114114
group_id,
115115
expr_id: self.expr_id,
116116
children_group_ids: expr.children.clone(),
117-
}),
118-
Some(optimizer),
117+
},
118+
optimizer,
119119
);
120120
optimizer.update_group_info(
121121
group_id,
@@ -197,8 +197,8 @@ impl<T: NodeType, M: Memo<T>> Task<T, M> for OptimizeInputsTask {
197197
&expr.typ,
198198
&preds,
199199
&input_statistics_ref,
200-
Some(context.clone()),
201-
Some(optimizer),
200+
context.clone(),
201+
optimizer,
202202
);
203203
let total_cost = cost.sum(&operation_cost, &input_cost);
204204

optd-core/src/cost.rs

+10-6
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,29 @@ pub struct Statistics(pub Box<dyn std::any::Any + Send + Sync + 'static>);
1616
pub struct Cost(pub Vec<f64>);
1717

1818
pub trait CostModel<T: NodeType, M: Memo<T>>: 'static + Send + Sync {
19-
/// Compute the cost of a single operation
19+
/// Compute the cost of a single operation. `RelNodeContext` might be
20+
/// optional in the future when we implement physical property enforcers.
21+
/// If we have not decided the winner for a child group yet, the statistics
22+
/// for that group will be `None`.
2023
#[allow(clippy::too_many_arguments)]
2124
fn compute_operation_cost(
2225
&self,
2326
node: &T,
2427
predicates: &[ArcPredNode<T>],
2528
children_stats: &[Option<&Statistics>],
26-
context: Option<RelNodeContext>,
27-
optimizer: Option<&CascadesOptimizer<T, M>>,
29+
context: RelNodeContext,
30+
optimizer: &CascadesOptimizer<T, M>,
2831
) -> Cost;
2932

30-
/// Derive the statistics of a single operation
33+
/// Derive the statistics of a single operation. `RelNodeContext` might be
34+
/// optional in the future when we implement physical property enforcers.
3135
fn derive_statistics(
3236
&self,
3337
node: &T,
3438
predicates: &[ArcPredNode<T>],
3539
children_stats: &[&Statistics],
36-
context: Option<RelNodeContext>,
37-
optimizer: Option<&CascadesOptimizer<T, M>>,
40+
context: RelNodeContext,
41+
optimizer: &CascadesOptimizer<T, M>,
3842
) -> Statistics;
3943

4044
fn explain_cost(&self, cost: &Cost) -> String;

optd-datafusion-repr-adv-cost/src/lib.rs

+21-41
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ impl CostModel<DfNodeType, NaiveMemo<DfNodeType>> for AdvancedCostModel {
6161
node: &DfNodeType,
6262
predicates: &[ArcDfPredNode],
6363
children_stats: &[Option<&Statistics>],
64-
context: Option<RelNodeContext>,
65-
optimizer: Option<&CascadesOptimizer<DfNodeType>>,
64+
context: RelNodeContext,
65+
optimizer: &CascadesOptimizer<DfNodeType>,
6666
) -> Cost {
6767
self.base_model
6868
.compute_operation_cost(node, predicates, children_stats, context, optimizer)
@@ -73,11 +73,9 @@ impl CostModel<DfNodeType, NaiveMemo<DfNodeType>> for AdvancedCostModel {
7373
node: &DfNodeType,
7474
predicates: &[ArcDfPredNode],
7575
children_stats: &[&Statistics],
76-
context: Option<RelNodeContext>,
77-
optimizer: Option<&CascadesOptimizer<DfNodeType>>,
76+
context: RelNodeContext,
77+
optimizer: &CascadesOptimizer<DfNodeType>,
7878
) -> Statistics {
79-
let context = context.as_ref();
80-
let optimizer = optimizer.as_ref();
8179
let row_cnts = children_stats
8280
.iter()
8381
.map(|child| DfCostModel::row_cnt(child))
@@ -100,12 +98,8 @@ impl CostModel<DfNodeType, NaiveMemo<DfNodeType>> for AdvancedCostModel {
10098
DfCostModel::stat(row_cnt)
10199
}
102100
DfNodeType::PhysicalFilter => {
103-
let output_schema = optimizer
104-
.unwrap()
105-
.get_schema_of(context.unwrap().group_id.into());
106-
let output_column_ref = optimizer
107-
.unwrap()
108-
.get_column_ref_of(context.unwrap().group_id.into());
101+
let output_schema = optimizer.get_schema_of(context.group_id.into());
102+
let output_column_ref = optimizer.get_column_ref_of(context.group_id.into());
109103
let row_cnt = self.stats.get_filter_row_cnt(
110104
row_cnts[0],
111105
output_schema,
@@ -115,18 +109,12 @@ impl CostModel<DfNodeType, NaiveMemo<DfNodeType>> for AdvancedCostModel {
115109
DfCostModel::stat(row_cnt)
116110
}
117111
DfNodeType::PhysicalNestedLoopJoin(join_typ) => {
118-
let output_schema = optimizer
119-
.unwrap()
120-
.get_schema_of(context.unwrap().group_id.into());
121-
let output_column_ref = optimizer
122-
.unwrap()
123-
.get_column_ref_of(context.unwrap().group_id.into());
124-
let left_column_ref = optimizer
125-
.unwrap()
126-
.get_column_ref_of(context.unwrap().children_group_ids[0].into());
127-
let right_column_ref = optimizer
128-
.unwrap()
129-
.get_column_ref_of(context.unwrap().children_group_ids[1].into());
112+
let output_schema = optimizer.get_schema_of(context.group_id.into());
113+
let output_column_ref = optimizer.get_column_ref_of(context.group_id.into());
114+
let left_column_ref =
115+
optimizer.get_column_ref_of(context.children_group_ids[0].into());
116+
let right_column_ref =
117+
optimizer.get_column_ref_of(context.children_group_ids[1].into());
130118
let row_cnt = self.stats.get_nlj_row_cnt(
131119
*join_typ,
132120
row_cnts[0],
@@ -140,18 +128,12 @@ impl CostModel<DfNodeType, NaiveMemo<DfNodeType>> for AdvancedCostModel {
140128
DfCostModel::stat(row_cnt)
141129
}
142130
DfNodeType::PhysicalHashJoin(join_typ) => {
143-
let output_schema = optimizer
144-
.unwrap()
145-
.get_schema_of(context.unwrap().group_id.into());
146-
let output_column_ref = optimizer
147-
.unwrap()
148-
.get_column_ref_of(context.unwrap().group_id.into());
149-
let left_column_ref = optimizer
150-
.unwrap()
151-
.get_column_ref_of(context.unwrap().children_group_ids[0].into());
152-
let right_column_ref = optimizer
153-
.unwrap()
154-
.get_column_ref_of(context.unwrap().children_group_ids[1].into());
131+
let output_schema = optimizer.get_schema_of(context.group_id.into());
132+
let output_column_ref = optimizer.get_column_ref_of(context.group_id.into());
133+
let left_column_ref =
134+
optimizer.get_column_ref_of(context.children_group_ids[0].into());
135+
let right_column_ref =
136+
optimizer.get_column_ref_of(context.children_group_ids[1].into());
155137
let row_cnt = self.stats.get_hash_join_row_cnt(
156138
*join_typ,
157139
row_cnts[0],
@@ -166,9 +148,7 @@ impl CostModel<DfNodeType, NaiveMemo<DfNodeType>> for AdvancedCostModel {
166148
DfCostModel::stat(row_cnt)
167149
}
168150
DfNodeType::PhysicalAgg => {
169-
let output_column_ref = optimizer
170-
.unwrap()
171-
.get_column_ref_of(context.unwrap().group_id.into());
151+
let output_column_ref = optimizer.get_column_ref_of(context.group_id.into());
172152
let row_cnt = self
173153
.stats
174154
.get_agg_row_cnt(predicates[1].clone(), output_column_ref);
@@ -178,8 +158,8 @@ impl CostModel<DfNodeType, NaiveMemo<DfNodeType>> for AdvancedCostModel {
178158
node,
179159
predicates,
180160
children_stats,
181-
context.cloned(),
182-
optimizer.copied(),
161+
context,
162+
optimizer,
183163
),
184164
}
185165
}

optd-datafusion-repr/src/cost/adaptive_cost.rs

+6-8
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,9 @@ pub struct AdaptiveCostModel {
2828
}
2929

3030
impl AdaptiveCostModel {
31-
fn get_row_cnt(&self, context: &Option<RelNodeContext>) -> f64 {
31+
fn get_row_cnt(&self, context: &RelNodeContext) -> f64 {
3232
let guard = self.runtime_row_cnt.lock().unwrap();
33-
if let Some((runtime_row_cnt, iter)) =
34-
guard.history.get(&context.as_ref().unwrap().group_id)
35-
{
33+
if let Some((runtime_row_cnt, iter)) = guard.history.get(&context.group_id) {
3634
if *iter + self.decay >= guard.iter_cnt {
3735
return (*runtime_row_cnt).max(1) as f64;
3836
}
@@ -67,8 +65,8 @@ impl CostModel<DfNodeType, NaiveMemo<DfNodeType>> for AdaptiveCostModel {
6765
node: &DfNodeType,
6866
predicates: &[ArcDfPredNode],
6967
children: &[Option<&Statistics>],
70-
context: Option<RelNodeContext>,
71-
optimizer: Option<&CascadesOptimizer<DfNodeType>>,
68+
context: RelNodeContext,
69+
optimizer: &CascadesOptimizer<DfNodeType>,
7270
) -> Cost {
7371
if let DfNodeType::PhysicalScan = node {
7472
let row_cnt = self.get_row_cnt(&context);
@@ -83,8 +81,8 @@ impl CostModel<DfNodeType, NaiveMemo<DfNodeType>> for AdaptiveCostModel {
8381
node: &DfNodeType,
8482
predicates: &[ArcDfPredNode],
8583
children: &[&Statistics],
86-
context: Option<RelNodeContext>,
87-
optimizer: Option<&CascadesOptimizer<DfNodeType>>,
84+
context: RelNodeContext,
85+
optimizer: &CascadesOptimizer<DfNodeType>,
8886
) -> Statistics {
8987
if let DfNodeType::PhysicalScan = node {
9088
let row_cnt = self.get_row_cnt(&context);

optd-datafusion-repr/src/cost/base_cost.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ impl CostModel<DfNodeType, NaiveMemo<DfNodeType>> for DfCostModel {
8989
node: &DfNodeType,
9090
predicates: &[ArcDfPredNode],
9191
children: &[&Statistics],
92-
_context: Option<RelNodeContext>,
93-
_optimizer: Option<&CascadesOptimizer<DfNodeType>>,
92+
_context: RelNodeContext,
93+
_optimizer: &CascadesOptimizer<DfNodeType>,
9494
) -> Statistics {
9595
match node {
9696
DfNodeType::PhysicalScan => {
@@ -132,8 +132,8 @@ impl CostModel<DfNodeType, NaiveMemo<DfNodeType>> for DfCostModel {
132132
node: &DfNodeType,
133133
predicates: &[ArcDfPredNode],
134134
children: &[Option<&Statistics>],
135-
_context: Option<RelNodeContext>,
136-
_optimizer: Option<&CascadesOptimizer<DfNodeType>>,
135+
_context: RelNodeContext,
136+
_optimizer: &CascadesOptimizer<DfNodeType>,
137137
) -> Cost {
138138
let row_cnts = children
139139
.iter()

optd-datafusion-repr/src/testing/dummy_cost.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ impl CostModel<DfNodeType, NaiveMemo<DfNodeType>> for DummyCostModel {
1919
_: &DfNodeType,
2020
_: &[ArcDfPredNode],
2121
_: &[Option<&Statistics>],
22-
_: Option<RelNodeContext>,
23-
_: Option<&CascadesOptimizer<DfNodeType>>,
22+
_: RelNodeContext,
23+
_: &CascadesOptimizer<DfNodeType>,
2424
) -> Cost {
2525
Cost(vec![1.0])
2626
}
@@ -31,8 +31,8 @@ impl CostModel<DfNodeType, NaiveMemo<DfNodeType>> for DummyCostModel {
3131
_: &DfNodeType,
3232
_: &[ArcDfPredNode],
3333
_: &[&Statistics],
34-
_: Option<RelNodeContext>,
35-
_: Option<&CascadesOptimizer<DfNodeType>>,
34+
_: RelNodeContext,
35+
_: &CascadesOptimizer<DfNodeType>,
3636
) -> Statistics {
3737
Statistics(Box::new(()))
3838
}

0 commit comments

Comments
 (0)