Skip to content
This repository was archived by the owner on Jan 7, 2025. It is now read-only.

Commit ee080d8

Browse files
authored
feat: add cost estimation for agg (#144)
Compute cost for aggregation. The cardinality is computed as the product of n-distinct of all the group-by columns. If there's no group by column, the output cardinality is just 1. This should fix the cardinality parity between postgres for Q14 and Q17. It also leads to a better join order in Q11, since aggregation is the child of a join. ## Misc - Add planner test for Q11. - Fixes Q14 and Q17. - Next step is to support n-distinct for string. - We may change to multi-dimension n-distinct after it's integrated.
1 parent f42a3cd commit ee080d8

File tree

6 files changed

+378
-93
lines changed

6 files changed

+378
-93
lines changed

optd-datafusion-repr/src/cost/base_cost.rs

+99-33
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use std::{collections::HashMap, sync::Arc};
22

33
use crate::plan_nodes::{
4-
BinOpType, ColumnRefExpr, ConstantExpr, ConstantType, LogOpType, OptRelNode, UnOpType,
4+
BinOpType, ColumnRefExpr, ConstantExpr, ConstantType, ExprList, LogOpType, OptRelNode, UnOpType,
55
};
66
use crate::properties::column_ref::{ColumnRefPropertyBuilder, GroupColumnRefs};
77
use crate::{
@@ -11,8 +11,8 @@ use crate::{
1111
use arrow_schema::{ArrowError, DataType};
1212
use datafusion::arrow::array::{
1313
Array, BooleanArray, Date32Array, Decimal128Array, Float32Array, Float64Array, Int16Array,
14-
Int32Array, Int8Array, RecordBatch, RecordBatchIterator, RecordBatchReader, UInt16Array,
15-
UInt32Array, UInt8Array,
14+
Int32Array, Int8Array, RecordBatch, RecordBatchIterator, RecordBatchReader, StringArray,
15+
UInt16Array, UInt32Array, UInt8Array,
1616
};
1717
use itertools::Itertools;
1818
use optd_core::{
@@ -22,6 +22,7 @@ use optd_core::{
2222
};
2323
use optd_gungnir::stats::hyperloglog::{self, HyperLogLog};
2424
use optd_gungnir::stats::tdigest::{self, TDigest};
25+
use optd_gungnir::utils::arith_encoder;
2526
use serde::{Deserialize, Serialize};
2627

2728
fn compute_plan_node_cost<T: RelNodeTyp, C: CostModel<T>>(
@@ -181,6 +182,7 @@ impl DataFusionPerTableStats {
181182
| DataType::UInt32
182183
| DataType::Float32
183184
| DataType::Float64
185+
| DataType::Utf8
184186
)
185187
}
186188

@@ -222,6 +224,10 @@ impl DataFusionPerTableStats {
222224
val as f64
223225
}
224226

227+
fn str_to_f64(string: &str) -> f64 {
228+
arith_encoder::encode(string)
229+
}
230+
225231
match col_type {
226232
DataType::Boolean => {
227233
generate_stats_for_col!({ col, distr, hll, BooleanArray, to_f64_safe })
@@ -256,6 +262,9 @@ impl DataFusionPerTableStats {
256262
DataType::Decimal128(_, _) => {
257263
generate_stats_for_col!({ col, distr, hll, Decimal128Array, i128_to_f64 })
258264
}
265+
DataType::Utf8 => {
266+
generate_stats_for_col!({ col, distr, hll, StringArray, str_to_f64 })
267+
}
259268
_ => unreachable!(),
260269
}
261270
}
@@ -323,6 +332,10 @@ const DEFAULT_EQ_SEL: f64 = 0.005;
323332
const DEFAULT_INEQ_SEL: f64 = 0.3333333333333333;
324333
// Default selectivity estimate for pattern-match operators such as LIKE
325334
const DEFAULT_MATCH_SEL: f64 = 0.005;
335+
// Default selectivity if we have no information
336+
const DEFAULT_UNK_SEL: f64 = 0.005;
337+
// Default n-distinct estimate for derived columns or columns lacking statistics
338+
const DEFAULT_N_DISTINCT: u64 = 200;
326339

327340
const INVALID_SEL: f64 = 0.01;
328341

@@ -401,37 +414,33 @@ impl<M: MostCommonValues, D: Distribution> CostModel<OptRelNodeTyp> for OptCostM
401414
OptRelNodeTyp::PhysicalEmptyRelation => Self::cost(0.5, 0.01, 0.0),
402415
OptRelNodeTyp::PhysicalLimit => {
403416
let (row_cnt, compute_cost, _) = Self::cost_tuple(&children[0]);
404-
let row_cnt = if let Some(context) = context {
405-
if let Some(optimizer) = optimizer {
406-
let mut fetch_expr =
407-
optimizer.get_all_group_bindings(context.children_group_ids[2], false);
408-
assert!(
409-
fetch_expr.len() == 1,
410-
"fetch expression should be the only expr in the group"
411-
);
412-
let fetch_expr = fetch_expr.pop().unwrap();
413-
assert!(
414-
matches!(
415-
fetch_expr.typ,
416-
OptRelNodeTyp::Constant(ConstantType::UInt64)
417-
),
418-
"fetch type can only be UInt64"
419-
);
420-
let fetch = ConstantExpr::from_rel_node(fetch_expr)
421-
.unwrap()
422-
.value()
423-
.as_u64();
424-
// u64::MAX represents None
425-
if fetch == u64::MAX {
426-
row_cnt
427-
} else {
428-
row_cnt.min(fetch as f64)
429-
}
417+
let row_cnt = if let (Some(context), Some(optimizer)) = (context, optimizer) {
418+
let mut fetch_expr =
419+
optimizer.get_all_group_bindings(context.children_group_ids[2], false);
420+
assert!(
421+
fetch_expr.len() == 1,
422+
"fetch expression should be the only expr in the group"
423+
);
424+
let fetch_expr = fetch_expr.pop().unwrap();
425+
assert!(
426+
matches!(
427+
fetch_expr.typ,
428+
OptRelNodeTyp::Constant(ConstantType::UInt64)
429+
),
430+
"fetch type can only be UInt64"
431+
);
432+
let fetch = ConstantExpr::from_rel_node(fetch_expr)
433+
.unwrap()
434+
.value()
435+
.as_u64();
436+
// u64::MAX represents None
437+
if fetch == u64::MAX {
438+
row_cnt
430439
} else {
431-
panic!("compute_cost() should not be called if optimizer is None")
440+
row_cnt.min(fetch as f64)
432441
}
433442
} else {
434-
panic!("compute_cost() should not be called if context is None")
443+
(row_cnt * DEFAULT_UNK_SEL).max(1.0)
435444
};
436445
Self::cost(row_cnt, compute_cost, 0.0)
437446
}
@@ -499,10 +508,15 @@ impl<M: MostCommonValues, D: Distribution> CostModel<OptRelNodeTyp> for OptCostM
499508
Self::cost(row_cnt, row_cnt * row_cnt.ln_1p().max(1.0), 0.0)
500509
}
501510
OptRelNodeTyp::PhysicalAgg => {
502-
let (row_cnt, _, _) = Self::cost_tuple(&children[0]);
511+
let child_row_cnt = Self::row_cnt(&children[0]);
512+
let row_cnt = self.get_agg_row_cnt(context, optimizer, child_row_cnt);
503513
let (_, compute_cost_1, _) = Self::cost_tuple(&children[1]);
504514
let (_, compute_cost_2, _) = Self::cost_tuple(&children[2]);
505-
Self::cost(row_cnt, row_cnt * (compute_cost_1 + compute_cost_2), 0.0)
515+
Self::cost(
516+
row_cnt,
517+
child_row_cnt * (compute_cost_1 + compute_cost_2),
518+
0.0,
519+
)
506520
}
507521
OptRelNodeTyp::List => {
508522
let compute_cost = children
@@ -544,6 +558,58 @@ impl<M: MostCommonValues, D: Distribution> OptCostModel<M, D> {
544558
}
545559
}
546560

561+
fn get_agg_row_cnt(
562+
&self,
563+
context: Option<RelNodeContext>,
564+
optimizer: Option<&CascadesOptimizer<OptRelNodeTyp>>,
565+
child_row_cnt: f64,
566+
) -> f64 {
567+
if let (Some(context), Some(optimizer)) = (context, optimizer) {
568+
let group_by_id = context.children_group_ids[2];
569+
let mut group_by_exprs: Vec<Arc<RelNode<OptRelNodeTyp>>> =
570+
optimizer.get_all_group_bindings(group_by_id, false);
571+
assert!(
572+
group_by_exprs.len() == 1,
573+
"ExprList expression should be the only expression in the GROUP BY group"
574+
);
575+
let group_by = group_by_exprs.pop().unwrap();
576+
let group_by = ExprList::from_rel_node(group_by).unwrap();
577+
if group_by.is_empty() {
578+
1.0
579+
} else {
580+
// Multiply the n-distinct of all the group by columns.
581+
// TODO: improve with multi-dimensional n-distinct
582+
let base_table_col_refs = optimizer
583+
.get_property_by_group::<ColumnRefPropertyBuilder>(context.group_id, 1);
584+
base_table_col_refs
585+
.iter()
586+
.take(group_by.len())
587+
.map(|col_ref| match col_ref {
588+
ColumnRef::BaseTableColumnRef { table, col_idx } => {
589+
let table_stats = self.per_table_stats_map.get(table);
590+
let column_stats = table_stats.map(|table_stats| {
591+
table_stats.per_column_stats_vec.get(*col_idx).unwrap()
592+
});
593+
594+
if let Some(Some(column_stats)) = column_stats {
595+
column_stats.ndistinct as f64
596+
} else {
597+
// The column type is not supported or stats are missing.
598+
DEFAULT_N_DISTINCT as f64
599+
}
600+
}
601+
ColumnRef::Derived => DEFAULT_N_DISTINCT as f64,
602+
_ => panic!(
603+
"GROUP BY base table column ref must either be derived or base table"
604+
),
605+
})
606+
.product()
607+
}
608+
} else {
609+
(child_row_cnt * DEFAULT_UNK_SEL).max(1.0)
610+
}
611+
}
612+
547613
/// The expr_tree input must be a "mixed expression tree"
548614
/// An "expression node" refers to a RelNode that returns true for is_expression()
549615
/// A "full expression tree" is where every node in the tree is an expression node

optd-datafusion-repr/src/properties/schema.rs

+25-9
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,26 @@ use std::sync::Arc;
44
use optd_core::property::PropertyBuilder;
55

66
use super::DEFAULT_NAME;
7-
use crate::plan_nodes::{ConstantType, EmptyRelationData, OptRelNodeTyp};
7+
use crate::plan_nodes::{ConstantType, EmptyRelationData, FuncType, OptRelNodeTyp};
88

99
#[derive(Clone, Debug, Serialize, Deserialize)]
1010
pub struct Field {
1111
pub name: String,
1212
pub typ: ConstantType,
1313
pub nullable: bool,
1414
}
15+
16+
impl Field {
17+
/// Generate a field that is only a place holder whose members are never used.
18+
fn placeholder() -> Self {
19+
Self {
20+
name: DEFAULT_NAME.to_string(),
21+
typ: ConstantType::Any,
22+
nullable: true,
23+
}
24+
}
25+
}
26+
1527
#[derive(Clone, Debug, Serialize, Deserialize)]
1628
pub struct Schema {
1729
pub fields: Vec<Field>,
@@ -87,14 +99,18 @@ impl PropertyBuilder<OptRelNodeTyp> for SchemaPropertyBuilder {
8799
Schema { fields }
88100
}
89101
OptRelNodeTyp::LogOp(_) => Schema {
90-
fields: vec![
91-
Field {
92-
name: DEFAULT_NAME.to_string(),
93-
typ: ConstantType::Any,
94-
nullable: true
95-
};
96-
children.len()
97-
],
102+
fields: vec![Field::placeholder(); children.len()],
103+
},
104+
OptRelNodeTyp::Agg => {
105+
let mut group_by_schema = children[1].clone();
106+
let agg_schema = children[2].clone();
107+
group_by_schema.fields.extend(agg_schema.fields);
108+
group_by_schema
109+
}
110+
OptRelNodeTyp::Func(FuncType::Agg(_)) => Schema {
111+
// TODO: this is just a place holder now.
112+
// The real type should be the column type.
113+
fields: vec![Field::placeholder()],
98114
},
99115
_ => Schema { fields: vec![] },
100116
}

optd-gungnir/src/stats/hyperloglog.rs

+8-1
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,17 @@ pub struct HyperLogLog {
2525
alpha: f64, // The normal HLL multiplier factor.
2626
}
2727

28+
// Serialize common data types for hashing (&str).
29+
impl ByteSerializable for &str {
30+
fn to_bytes(&self) -> Vec<u8> {
31+
self.as_bytes().to_vec()
32+
}
33+
}
34+
2835
// Serialize common data types for hashing (String).
2936
impl ByteSerializable for String {
3037
fn to_bytes(&self) -> Vec<u8> {
31-
self.as_bytes().to_vec()
38+
self.as_str().to_bytes()
3239
}
3340
}
3441

optd-sqlplannertest/src/lib.rs

-1
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ impl DatafusionDBMS {
140140
task: &str,
141141
flags: &[String],
142142
) -> Result<()> {
143-
println!("task_explain(): called on sql={}", sql);
144143
use std::fmt::Write;
145144

146145
let with_logical = flags.contains(&"with_logical".to_string());

0 commit comments

Comments
 (0)