Skip to content
This repository was archived by the owner on Jan 7, 2025. It is now read-only.

Commit 75253cf

Browse files
authored
Separate advanced cost model from the previous basic cost model (#200)
Tested only briefly, quick and dirty
1 parent 68c697d commit 75253cf

31 files changed

+1092
-643
lines changed

Cargo.lock

+23-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@ members = [
88
"optd-adaptive-demo",
99
"optd-gungnir",
1010
"optd-perfbench",
11+
"optd-datafusion-repr-adv-cost",
1112
]
1213
resolver = "2"

datafusion-optd-cli/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ tokio = { version = "1.24", features = [
5858
] }
5959
url = "2.2"
6060
optd-datafusion-bridge = { path = "../optd-datafusion-bridge" }
61+
optd-datafusion-repr-adv-cost = { path = "../optd-datafusion-repr-adv-cost" }
6162
optd-datafusion-repr = { path = "../optd-datafusion-repr" }
6263
tracing-subscriber = "0.3"
6364
tracing = "0.1"

datafusion-optd-cli/src/main.rs

+16-6
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ use datafusion_optd_cli::{
3030
};
3131
use mimalloc::MiMalloc;
3232
use optd_datafusion_bridge::{DatafusionCatalog, OptdQueryPlanner};
33-
use optd_datafusion_repr::cost::BaseTableStats;
3433
use optd_datafusion_repr::DatafusionOptimizer;
34+
use optd_datafusion_repr_adv_cost::adv_cost::stats::BaseTableStats;
3535
use std::collections::HashMap;
3636
use std::env;
3737
use std::path::Path;
@@ -141,6 +141,9 @@ struct Args {
141141

142142
#[clap(long, help = "Turn on adaptive optimization")]
143143
enable_adaptive: bool,
144+
145+
#[clap(long, help = "Use advanced cost model")]
146+
adv_cost: bool,
144147
}
145148

146149
#[tokio::main]
@@ -204,11 +207,18 @@ pub async fn main() -> Result<()> {
204207
state = state.with_physical_optimizer_rules(vec![]);
205208
}
206209
// use optd-bridge query planner
207-
let optimizer = DatafusionOptimizer::new_physical(
208-
Arc::new(DatafusionCatalog::new(state.catalog_list())),
209-
BaseTableStats::default(),
210-
args.enable_adaptive,
211-
);
210+
let optimizer = if args.adv_cost {
211+
optd_datafusion_repr_adv_cost::new_physical_adv_cost(
212+
Arc::new(DatafusionCatalog::new(state.catalog_list())),
213+
BaseTableStats::default(),
214+
args.enable_adaptive,
215+
)
216+
} else {
217+
DatafusionOptimizer::new_physical(
218+
Arc::new(DatafusionCatalog::new(state.catalog_list())),
219+
args.enable_adaptive,
220+
)
221+
};
212222
state = state.with_query_planner(Arc::new(OptdQueryPlanner::new(optimizer)));
213223
SessionContext::new_with_state(state)
214224
};

optd-adaptive-demo/src/bin/optd-adaptive-three-join.rs

-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ use datafusion_optd_cli::{
1010
};
1111
use mimalloc::MiMalloc;
1212
use optd_datafusion_bridge::{DatafusionCatalog, OptdQueryPlanner};
13-
use optd_datafusion_repr::cost::BaseTableStats;
1413
use optd_datafusion_repr::DatafusionOptimizer;
1514
use rand::{thread_rng, Rng};
1615
use std::sync::Arc;
@@ -30,7 +29,6 @@ async fn main() -> Result<()> {
3029
SessionState::new_with_config_rt(session_config.clone(), Arc::new(runtime_env));
3130
let mut optimizer: DatafusionOptimizer = DatafusionOptimizer::new_physical(
3231
Arc::new(DatafusionCatalog::new(state.catalog_list())),
33-
BaseTableStats::default(),
3432
true,
3533
);
3634
optimizer.optd_optimizer_mut().prop.partial_explore_iter = None;

optd-adaptive-demo/src/bin/optd-adaptive-tpch-q8.rs

-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ use datafusion_optd_cli::{
1010
};
1111
use mimalloc::MiMalloc;
1212
use optd_datafusion_bridge::{DatafusionCatalog, OptdQueryPlanner};
13-
use optd_datafusion_repr::cost::BaseTableStats;
1413
use optd_datafusion_repr::DatafusionOptimizer;
1514
use std::sync::Arc;
1615
use std::time::Duration;
@@ -31,7 +30,6 @@ async fn main() -> Result<()> {
3130
SessionState::new_with_config_rt(session_config.clone(), Arc::new(runtime_env));
3231
let optimizer: DatafusionOptimizer = DatafusionOptimizer::new_physical(
3332
Arc::new(DatafusionCatalog::new(state.catalog_list())),
34-
BaseTableStats::default(),
3533
true,
3634
);
3735
// clean up optimizer rules so that we can plug in our own optimizer
+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
[package]
2+
name = "optd-datafusion-repr-adv-cost"
3+
version = "0.1.0"
4+
edition = "2021"
5+
6+
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
7+
8+
[dependencies]
9+
anyhow = "1"
10+
arrow-schema = "47.0.0"
11+
assert_approx_eq = "1.1.0"
12+
datafusion = "32.0.0"
13+
ordered-float = "4"
14+
optd-datafusion-repr = { path = "../optd-datafusion-repr" }
15+
optd-core = { path = "../optd-core" }
16+
serde = { version = "1.0", features = ["derive"] }
17+
rayon = "1.10"
18+
itertools = "0.11"
19+
test-case = "3.3"
20+
tracing = "0.1"
21+
tracing-subscriber = "0.3"
22+
optd-gungnir = { path = "../optd-gungnir" }
23+
serde_with = {version = "3.7.0", features = ["json"]}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
use std::sync::{Arc, Mutex};
2+
3+
use crate::adv_cost::OptCostModel;
4+
use optd_core::{
5+
cascades::{CascadesOptimizer, RelNodeContext},
6+
cost::{Cost, CostModel},
7+
rel_node::{RelNode, Value},
8+
};
9+
use optd_datafusion_repr::{
10+
cost::adaptive_cost::RuntimeAdaptionStorageInner, plan_nodes::OptRelNodeTyp,
11+
};
12+
use serde::{de::DeserializeOwned, Serialize};
13+
14+
use super::adv_cost::stats::{
15+
BaseTableStats, DataFusionDistribution, DataFusionMostCommonValues, Distribution,
16+
MostCommonValues,
17+
};
18+
19+
/// Shared, mutex-protected handle to a `RuntimeAdaptionStorageInner`.
pub type RuntimeAdaptionStorage = Arc<Mutex<RuntimeAdaptionStorageInner>>;
20+
/// `AdaptiveCostModel` specialized to the DataFusion most-common-values and
/// distribution statistics types.
pub type DataFusionAdaptiveCostModel =
21+
AdaptiveCostModel<DataFusionMostCommonValues, DataFusionDistribution>;
22+
23+
/// Default decay value.
// NOTE(review): presumably meant as the `decay` argument of
// `AdaptiveCostModel::new`; it is not referenced in the visible code, so
// confirm against the call sites.
pub const DEFAULT_DECAY: usize = 50;
24+
25+
pub struct AdaptiveCostModel<
26+
M: MostCommonValues + Serialize + DeserializeOwned,
27+
D: Distribution + Serialize + DeserializeOwned,
28+
> {
29+
runtime_row_cnt: RuntimeAdaptionStorage,
30+
base_model: OptCostModel<M, D>,
31+
decay: usize,
32+
}
33+
34+
impl<
35+
M: MostCommonValues + Serialize + DeserializeOwned,
36+
D: Distribution + Serialize + DeserializeOwned,
37+
> CostModel<OptRelNodeTyp> for AdaptiveCostModel<M, D>
38+
{
39+
fn explain(&self, cost: &Cost) -> String {
40+
self.base_model.explain(cost)
41+
}
42+
43+
fn accumulate(&self, total_cost: &mut Cost, cost: &Cost) {
44+
self.base_model.accumulate(total_cost, cost)
45+
}
46+
47+
fn zero(&self) -> Cost {
48+
self.base_model.zero()
49+
}
50+
51+
fn compute_cost(
52+
&self,
53+
node: &OptRelNodeTyp,
54+
data: &Option<Value>,
55+
children: &[Cost],
56+
context: Option<RelNodeContext>,
57+
optimizer: Option<&CascadesOptimizer<OptRelNodeTyp>>,
58+
) -> Cost {
59+
if let OptRelNodeTyp::PhysicalScan = node {
60+
let guard = self.runtime_row_cnt.lock().unwrap();
61+
if let Some((runtime_row_cnt, iter)) =
62+
guard.history.get(&context.as_ref().unwrap().group_id)
63+
{
64+
if *iter + self.decay >= guard.iter_cnt {
65+
let runtime_row_cnt = (*runtime_row_cnt).max(1) as f64;
66+
return OptCostModel::<M, D>::cost(runtime_row_cnt, 0.0, runtime_row_cnt);
67+
}
68+
}
69+
}
70+
let (mut row_cnt, compute_cost, io_cost) = OptCostModel::<M, D>::cost_tuple(
71+
&self
72+
.base_model
73+
.compute_cost(node, data, children, context.clone(), optimizer),
74+
);
75+
if let Some(context) = context {
76+
let guard = self.runtime_row_cnt.lock().unwrap();
77+
if let Some((runtime_row_cnt, iter)) = guard.history.get(&context.group_id) {
78+
if *iter + self.decay >= guard.iter_cnt {
79+
let runtime_row_cnt = (*runtime_row_cnt).max(1) as f64;
80+
row_cnt = runtime_row_cnt;
81+
}
82+
}
83+
}
84+
OptCostModel::<M, D>::cost(row_cnt, compute_cost, io_cost)
85+
}
86+
87+
fn compute_plan_node_cost(&self, node: &RelNode<OptRelNodeTyp>) -> Cost {
88+
self.base_model.compute_plan_node_cost(node)
89+
}
90+
}
91+
92+
impl<
93+
M: MostCommonValues + Serialize + DeserializeOwned,
94+
D: Distribution + Serialize + DeserializeOwned,
95+
> AdaptiveCostModel<M, D>
96+
{
97+
pub fn new(decay: usize, stats: BaseTableStats<M, D>) -> Self {
98+
Self {
99+
runtime_row_cnt: RuntimeAdaptionStorage::default(),
100+
base_model: OptCostModel::new(stats),
101+
decay,
102+
}
103+
}
104+
105+
pub fn get_runtime_map(&self) -> RuntimeAdaptionStorage {
106+
self.runtime_row_cnt.clone()
107+
}
108+
}

0 commit comments

Comments
 (0)